Tmp dump (running on images with a hardcoded size.)

Nicolas Patry 2024-04-04 21:42:57 +00:00
parent 5f4b395480
commit df4c700828
5 changed files with 64 additions and 339 deletions

View File

@@ -8,6 +8,11 @@ from transformers.modeling_attn_mask_utils import (
     _create_4d_causal_attention_mask,
     _prepare_4d_attention_mask,
 )
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPooling,
+    ImageClassifierOutput,
+)
 from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
 from text_generation_server.utils.layers import (
@@ -147,7 +152,6 @@ class CLIPAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         causal_attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -267,7 +271,6 @@ class CLIPEncoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         causal_attention_mask: torch.Tensor,
-        output_attentions: Optional[bool] = False,
     ):
         """
         Args:
@@ -275,9 +278,6 @@ class CLIPEncoderLayer(nn.Module):
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
        """
        residual = hidden_states
@@ -286,7 +286,6 @@ class CLIPEncoderLayer(nn.Module):
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
-            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states
@@ -346,14 +345,6 @@ CLIP_TEXT_INPUTS_DOCSTRING = r"""
            config.max_position_embeddings - 1]`.
            [What are position IDs?](../glossary#position-ids)
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """

 CLIP_VISION_INPUTS_DOCSTRING = r"""
@@ -361,14 +352,6 @@ CLIP_VISION_INPUTS_DOCSTRING = r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """

 CLIP_INPUTS_DOCSTRING = r"""
@@ -398,14 +381,6 @@ CLIP_INPUTS_DOCSTRING = r"""
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -435,9 +410,6 @@ class CLIPEncoder(nn.Module):
         inputs_embeds,
         attention_mask: Optional[torch.Tensor] = None,
         causal_attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
         r"""
         Args:
@@ -459,43 +431,16 @@ class CLIPEncoder(nn.Module):
                 - 0 for tokens that are **masked**.
                 [What are attention masks?](../glossary#attention-mask)
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-        encoder_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
         hidden_states = inputs_embeds
         for idx, encoder_layer in enumerate(self.layers):
-            layer_outputs = encoder_layer(
+            hidden_states = encoder_layer(
                 hidden_states,
                 attention_mask,
                 causal_attention_mask,
-                output_attentions=output_attentions,
             )
-            hidden_states = layer_outputs[0]
         return hidden_states
@@ -518,28 +463,11 @@ class CLIPTextTransformer(nn.Module):
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
         r"""
         Returns:
         """
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
         if input_ids is None:
             raise ValueError("You have to specify input_ids")
@@ -564,9 +492,6 @@ class CLIPTextTransformer(nn.Module):
             inputs_embeds=hidden_states,
             attention_mask=attention_mask,
             causal_attention_mask=causal_attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )
         last_hidden_state = encoder_outputs[0]
@@ -621,9 +546,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
         r"""
         Returns:
@@ -650,9 +572,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )
@@ -671,37 +590,16 @@ class CLIPVisionTransformer(nn.Module):
         self.encoder = CLIPEncoder(
             prefix=f"{prefix}.encoder", config=config, weights=weights
         )
-        self.post_layernorm = nn.LayerNorm.load(
-            prefix=f"{prefix}.post_layernorm",
-            weights=weights,
-            eps=config.layer_norm_eps,
-        )
+        # self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps)

     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
         r"""
         Returns:
         """
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
@@ -710,23 +608,15 @@ class CLIPVisionTransformer(nn.Module):
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )
-        last_hidden_state = encoder_outputs[0]
-        pooled_output = last_hidden_state[:, 0, :]
-        pooled_output = self.post_layernorm(pooled_output)
-        if not return_dict:
-            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+        last_hidden_state = encoder_outputs
+        # pooled_output = last_hidden_state[:, 0, :]
+        # pooled_output = self.post_layernorm(pooled_output)

         return BaseModelOutputWithPooling(
             last_hidden_state=last_hidden_state,
-            pooler_output=pooled_output,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
+            # pooler_output=pooled_output,
+            # hidden_states=encoder_outputs,
         )
@@ -747,9 +637,6 @@ class CLIPVisionModel(CLIPPreTrainedModel):
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
         r"""
         Returns:
@@ -779,9 +666,6 @@ class CLIPVisionModel(CLIPPreTrainedModel):
         return self.vision_model(
             pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )
@@ -816,9 +700,6 @@ class CLIPModel(nn.Module):
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ) -> torch.FloatTensor:
         r"""
         Returns:
@@ -836,28 +717,10 @@ class CLIPModel(nn.Module):
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
-        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
        )
        pooled_output = text_outputs[1]
@@ -868,9 +731,6 @@ class CLIPModel(nn.Module):
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
@@ -895,25 +755,8 @@ class CLIPModel(nn.Module):
        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
        )
        pooled_output = vision_outputs[1]  # pooled_output
@@ -927,10 +770,6 @@ class CLIPModel(nn.Module):
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
-        return_loss: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
    ):
        r"""
        Returns:
@@ -957,24 +796,8 @@ class CLIPModel(nn.Module):
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
@@ -982,8 +805,6 @@ class CLIPModel(nn.Module):
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

View File

@@ -376,7 +376,7 @@ class FlashLlamaModel(torch.nn.Module):
     def forward(
         self,
-        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -385,7 +385,7 @@ class FlashLlamaModel(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        hidden_states = inputs_embeds

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -437,8 +437,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
-            input_ids,
+            inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
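For context: the change above is the hook the LLaVA-Next code below relies on. FlashLlamaModel.forward now consumes embeddings instead of token ids, and FlashLlamaForCausalLM performs the lookup itself, so a multimodal wrapper can build and edit the embeddings before the language model runs. A minimal sketch of the new calling pattern (illustrative only; the sizes and ids are made up and not taken from this diff):

import torch
from torch import nn

# The wrapper embeds first; for plain text this is equivalent to the old
# behaviour where the model embedded `input_ids` internally.
embed_tokens = nn.Embedding(32000, 16)            # hypothetical vocab / hidden sizes
input_ids = torch.tensor([1, 15043, 29892, 3186])
inputs_embeds = embed_tokens(input_ids)
# hidden_states = self.model(inputs_embeds, position_ids, cu_seqlen_prefill, ...)
# A VLM wrapper gets a chance to overwrite some rows of `inputs_embeds`
# with projected image features before making that call.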

View File

@@ -22,7 +22,11 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.image_processing_utils import select_best_resolution
-from transformers import AutoModel, AutoModelForCausalLM
+from text_generation_server.utils.layers import (
+    TensorParallelColumnLinear,
+    TensorParallelRowLinear,
+)

 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
@@ -83,15 +87,15 @@ def unpad_image(tensor, original_size):
 # Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
 class LlavaNextMultiModalProjector(nn.Module):
-    def __init__(self, config):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        self.linear_1 = nn.Linear(
-            config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
+        self.linear_1 = TensorParallelColumnLinear.load(
+            prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
         )
         self.act = ACT2FN[config.projector_hidden_act]
-        self.linear_2 = nn.Linear(
-            config.text_config.hidden_size, config.text_config.hidden_size, bias=True
+        self.linear_2 = TensorParallelRowLinear.load(
+            prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
         )

     def forward(self, image_features):
@@ -135,13 +139,19 @@ class LlavaNextForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
+        vision_config = config.vision_config
+        # Instead of selecting in hidden_states[-2].
+        # Instead compute only the n -2 + 1 layers and don't pool
+        vision_config.num_hidden_layers += config.vision_feature_layer + 1
         self.vision_tower = load_vision_model(
             prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
             weights=weights,
         )
-        self.multi_modal_projector = LlavaNextMultiModalProjector(config)
+        self.multi_modal_projector = LlavaNextMultiModalProjector(
+            prefix="multi_modal_projector", config=config, weights=weights
+        )

         self.image_newline = weights.get_tensor("image_newline")
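A quick sanity check on the layer arithmetic in the added lines: with the LLaVA-Next default vision_feature_layer = -2, building the CLIP vision tower with num_hidden_layers + vision_feature_layer + 1 layers makes its final output equal to what hidden_states[-2] of the full tower would have been, which is also why post_layernorm is commented out in the CLIP diff above. Illustrative numbers only, assuming a 24-layer CLIP ViT-L vision tower:

num_hidden_layers = 24           # assumed CLIP ViT-L/14 vision tower
vision_feature_layer = -2        # LLaVA-Next default: penultimate layer
truncated_layers = num_hidden_layers + vision_feature_layer + 1   # 24 - 2 + 1 = 23
# The full tower returns num_hidden_layers + 1 hidden states (embeddings plus
# one per layer), so hidden_states[-2] is the output of layer 23 -- exactly
# the last layer of the truncated tower.
assert truncated_layers == (num_hidden_layers + 1) + vision_feature_layer == 23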
@@ -158,114 +168,17 @@ class LlavaNextForConditionalGeneration(nn.Module):
             config.pad_token_id if config.pad_token_id is not None else -1
         )

-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features
     def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds, input_ids, attention_mask, labels
+        self,
+        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
+        image_features: torch.Tensor,
     ):
-        num_images, num_image_patches, embed_dim = image_features.shape
-        batch_size, sequence_length = input_ids.shape
-        left_padding = not torch.sum(
-            input_ids[:, -1] == torch.tensor(self.pad_token_id)
-        )
-        # 1. Create a mask to know where special image tokens are
-        special_image_token_mask = input_ids == self.config.image_token_index
-        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
-        # Compute the maximum embed dimension
-        max_embed_dim = (
-            num_special_image_tokens.max() * (num_image_patches - 1)
-        ) + sequence_length
-        batch_indices, non_image_indices = torch.where(
-            input_ids != self.config.image_token_index
-        )
-        # 2. Compute the positions where text should be written
-        # Calculate new positions for text tokens in merged image-text sequence.
-        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
-        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
-        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
-        new_token_positions = (
-            torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1)
-            - 1
-        )
-        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
-        if left_padding:
-            new_token_positions += nb_image_pad[:, None]  # offset for left padding
-        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-        # 3. Create the full embedding, already padded to the maximum position
-        final_embedding = torch.zeros(
-            batch_size,
-            max_embed_dim,
-            embed_dim,
-            dtype=inputs_embeds.dtype,
-            device=inputs_embeds.device,
-        )
-        final_attention_mask = torch.zeros(
-            batch_size,
-            max_embed_dim,
-            dtype=attention_mask.dtype,
-            device=inputs_embeds.device,
-        )
-        if labels is not None:
-            final_labels = torch.full(
-                (batch_size, max_embed_dim),
-                self.config.ignore_index,
-                dtype=input_ids.dtype,
-                device=input_ids.device,
-            )
-        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
-        # set the corresponding tensors into their correct target device.
-        target_device = inputs_embeds.device
-        batch_indices, non_image_indices, text_to_overwrite = (
-            batch_indices.to(target_device),
-            non_image_indices.to(target_device),
-            text_to_overwrite.to(target_device),
-        )
-        attention_mask = attention_mask.to(target_device)
-        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
-        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
-            batch_indices, non_image_indices
-        ]
-        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
-            batch_indices, non_image_indices
-        ]
-        if labels is not None:
-            final_labels[batch_indices, text_to_overwrite] = labels[
-                batch_indices, non_image_indices
-            ]
-        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
-        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
-        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[
-            :, None
-        ].to(target_device)
-        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
-            raise ValueError(
-                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
-                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
-            )
-        final_embedding[image_to_overwrite] = (
-            image_features.contiguous().reshape(-1, embed_dim).to(target_device)
-        )
-        final_attention_mask |= image_to_overwrite
-        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
-            (final_attention_mask == 0), 1
-        )
-        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
-        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
-        indices_to_mask = new_token_positions[batch_indices, pad_indices]
-        final_embedding[batch_indices, indices_to_mask] = 0
-        if labels is None:
-            final_labels = None
-        return final_embedding, final_attention_mask, final_labels, position_ids
+        """In place merges in vision_embeddings with inputs_embeds."""
+        mask = input_ids == self.config.image_token_index
+        # Let's pray we have enabled enough slots !
+        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        return inputs_embeds

     def forward(
         self,
@@ -282,15 +195,11 @@ class LlavaNextForConditionalGeneration(nn.Module):
         pixel_values: torch.FloatTensor = None,
         image_sizes: Optional[torch.LongTensor] = None,
     ):
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
         if pixel_values is not None and len(pixel_values) > 0:
-            num_special_image_tokens = (
-                input_ids == self.config.image_token_index
-            ).sum()
-            assert num_special_image_tokens == len(
-                pixel_values
-            ), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+            # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
+            # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
             # 1. Extract the input embeddings
-            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
             # 2. Merge text and images
             num_images, num_patches, channels, height, width = pixel_values.shape
@@ -299,9 +208,9 @@ class LlavaNextForConditionalGeneration(nn.Module):
             )
             image_features = self.vision_tower(pixel_values)
-            selected_image_feature = image_features.hidden_states[
-                self.config.vision_feature_layer
-            ]
+            # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
+            # Already done within the clip model
+            selected_image_feature = image_features.last_hidden_state

             if self.config.vision_feature_select_strategy == "default":
                 selected_image_feature = selected_image_feature[:, 1:]
@@ -368,26 +277,21 @@ class LlavaNextForConditionalGeneration(nn.Module):
                 new_image_features.append(image_feature)
             image_features = torch.stack(new_image_features, dim=0)
-            inputs_embeds, attention_mask, labels, position_ids = (
-                self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, labels
-                )
+            inputs_embeds = self._merge_input_ids_with_image_features(
+                input_ids, inputs_embeds, image_features
             )
-            if labels is None:
-                labels = torch.full_like(attention_mask, self.config.ignore_index).to(
-                    torch.long
-                )
-        logits = self.language_model(
-            input_ids,
-            position_ids,
-            cu_seqlen_prefill,
-            kv_cache,
-            block_tables,
-            slots,
-            input_lengths,
-            max_s,
-            prefill_cache_indices,
-            lm_head_indices,
+        hidden_states = self.language_model.model(
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
         )
-        return logits
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.language_model.lm_head(hidden_states)
+        return logits, speculative_logits
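The rewritten merge path above leans on an invariant the old transformers implementation did not require: the tokenized prompt must already contain exactly one <image> placeholder per image-feature row, so the merge reduces to a single masked, in-place overwrite (the hard-coded <image> expansion in the batch code below is what currently guarantees that). A toy demonstration of the mechanism with made-up sizes (not TGI code):

import torch

image_token_index = 32000                        # hypothetical <image> token id
input_ids = torch.tensor([1, 32000, 32000, 32000, 9])
inputs_embeds = torch.zeros(5, 4)                # stand-in for token embeddings
image_features = torch.arange(12.0).view(3, 4)   # 3 feature rows for 3 <image> tokens

mask = input_ids == image_token_index
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
# Rows 1..3 now hold the image features; a count mismatch generally surfaces
# here as a shape/broadcast error rather than a silent misalignment.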

View File

@@ -7,7 +7,6 @@ import numpy as np
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
-from transformers.models.llama import LlamaTokenizerFast
 from typing import Optional, Tuple, Type

 from text_generation_server.pb import generate_pb2

View File

@@ -52,7 +52,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                 if chunk["type"] == "text":
                     full_text += chunk["content"]
                 elif chunk["type"] == "image":
-                    full_text += "<image>"
+                    full_text += "<image>" * 2928
                     images.append(chunk["content"])
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk['type']}")
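The 2928 is the hard-coded size mentioned in the commit title: every image chunk is expanded to a fixed number of <image> placeholder tokens before tokenization. As a back-of-the-envelope guess (mine, not stated anywhere in the diff), that count matches a LLaVA-Next input preprocessed to 672x672 with a 336px, patch-14 CLIP tower: a 24x24 base feature map, a 2x2 grid of 24x24 tiles, and one image_newline token per row of the 48-row tiled map.

# Back-of-the-envelope only; the resolution and patch size are assumptions.
base = 24 * 24              # 576 features from the resized base image
tiles = 4 * (24 * 24)       # 2304 features from the 2x2 high-resolution grid
newlines = 2 * 24           # 48 image_newline tokens, one per tiled-map row
assert base + tiles + newlines == 2928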