mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Tmp dump (running on images hardcoded size.)
This commit is contained in:
parent
5f4b395480
commit
df4c700828
@ -8,6 +8,11 @@ from transformers.modeling_attn_mask_utils import (
|
||||
_create_4d_causal_attention_mask,
|
||||
_prepare_4d_attention_mask,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
ImageClassifierOutput,
|
||||
)
|
||||
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
@ -147,7 +152,6 @@ class CLIPAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
@ -267,7 +271,6 @@ class CLIPEncoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
causal_attention_mask: torch.Tensor,
|
||||
output_attentions: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -275,9 +278,6 @@ class CLIPEncoderLayer(nn.Module):
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
`(config.encoder_attention_heads,)`.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
@ -286,7 +286,6 @@ class CLIPEncoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -346,14 +345,6 @@ CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
||||
config.max_position_embeddings - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
CLIP_VISION_INPUTS_DOCSTRING = r"""
|
||||
@ -361,14 +352,6 @@ CLIP_VISION_INPUTS_DOCSTRING = r"""
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
CLIP_INPUTS_DOCSTRING = r"""
|
||||
@ -398,14 +381,6 @@ CLIP_INPUTS_DOCSTRING = r"""
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||
return_loss (`bool`, *optional*):
|
||||
Whether or not to return the contrastive loss.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@ -435,9 +410,6 @@ class CLIPEncoder(nn.Module):
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
@ -459,43 +431,16 @@ class CLIPEncoder(nn.Module):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -518,28 +463,11 @@ class CLIPTextTransformer(nn.Module):
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
raise ValueError("You have to specify input_ids")
|
||||
|
||||
@ -564,9 +492,6 @@ class CLIPTextTransformer(nn.Module):
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
@ -621,9 +546,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
@ -650,9 +572,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
@ -671,37 +590,16 @@ class CLIPVisionTransformer(nn.Module):
|
||||
self.encoder = CLIPEncoder(
|
||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||
)
|
||||
self.post_layernorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.post_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
# self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
@ -710,23 +608,15 @@ class CLIPVisionTransformer(nn.Module):
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
last_hidden_state = encoder_outputs
|
||||
# pooled_output = last_hidden_state[:, 0, :]
|
||||
# pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
# pooler_output=pooled_output,
|
||||
# hidden_states=encoder_outputs,
|
||||
)
|
||||
|
||||
|
||||
@ -747,9 +637,6 @@ class CLIPVisionModel(CLIPPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
@ -779,9 +666,6 @@ class CLIPVisionModel(CLIPPreTrainedModel):
|
||||
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
@ -816,9 +700,6 @@ class CLIPModel(nn.Module):
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
@ -836,28 +717,10 @@ class CLIPModel(nn.Module):
|
||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||
>>> text_features = model.get_text_features(**inputs)
|
||||
```"""
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = text_outputs[1]
|
||||
@ -868,9 +731,6 @@ class CLIPModel(nn.Module):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
@ -895,25 +755,8 @@ class CLIPModel(nn.Module):
|
||||
>>> image_features = model.get_image_features(**inputs)
|
||||
```"""
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = vision_outputs[1] # pooled_output
|
||||
@ -927,10 +770,6 @@ class CLIPModel(nn.Module):
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
return_loss: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
@ -957,24 +796,8 @@ class CLIPModel(nn.Module):
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
```"""
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@ -982,8 +805,6 @@ class CLIPModel(nn.Module):
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
@ -376,7 +376,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -385,7 +385,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
@ -437,8 +437,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
|
@ -22,7 +22,11 @@ from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
from transformers import AutoModel, AutoModelForCausalLM
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
@ -83,15 +87,15 @@ def unpad_image(tensor, original_size):
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
|
||||
class LlavaNextMultiModalProjector(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
|
||||
self.linear_1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=True
|
||||
self.linear_2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
@ -135,13 +139,19 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
vision_config = config.vision_config
|
||||
# Instead of selecting in hidden_states[-2].
|
||||
# Instead compute only the n -2 + 1 layers and don't pool
|
||||
vision_config.num_hidden_layers += config.vision_feature_layer + 1
|
||||
self.vision_tower = load_vision_model(
|
||||
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||
config=config.vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.multi_modal_projector = LlavaNextMultiModalProjector(config)
|
||||
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||
prefix="multi_modal_projector", config=config, weights=weights
|
||||
)
|
||||
|
||||
self.image_newline = weights.get_tensor("image_newline")
|
||||
|
||||
@ -158,114 +168,17 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features
|
||||
def _merge_input_ids_with_image_features(
|
||||
self, image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_features: torch.Tensor,
|
||||
):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(
|
||||
input_ids[:, -1] == torch.tensor(self.pad_token_id)
|
||||
)
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == self.config.image_token_index
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# Compute the maximum embed dimension
|
||||
max_embed_dim = (
|
||||
num_special_image_tokens.max() * (num_image_patches - 1)
|
||||
) + sequence_length
|
||||
batch_indices, non_image_indices = torch.where(
|
||||
input_ids != self.config.image_token_index
|
||||
)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
new_token_positions = (
|
||||
torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1)
|
||||
- 1
|
||||
)
|
||||
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
||||
if left_padding:
|
||||
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
batch_size,
|
||||
max_embed_dim,
|
||||
embed_dim,
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size,
|
||||
max_embed_dim,
|
||||
dtype=attention_mask.dtype,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim),
|
||||
self.config.ignore_index,
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_image_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_image_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
|
||||
batch_indices, non_image_indices
|
||||
]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
|
||||
batch_indices, non_image_indices
|
||||
]
|
||||
if labels is not None:
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[
|
||||
batch_indices, non_image_indices
|
||||
]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
||||
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[
|
||||
:, None
|
||||
].to(target_device)
|
||||
|
||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||
raise ValueError(
|
||||
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
|
||||
final_embedding[image_to_overwrite] = (
|
||||
image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
|
||||
(final_attention_mask == 0), 1
|
||||
)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
mask = input_ids == self.config.image_token_index
|
||||
# Let's pray we have enabled enough slots !
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -282,15 +195,11 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
num_special_image_tokens = (
|
||||
input_ids == self.config.image_token_index
|
||||
).sum()
|
||||
assert num_special_image_tokens == len(
|
||||
pixel_values
|
||||
), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||
# 1. Extract the input embeddings
|
||||
|
||||
# 2. Merge text and images
|
||||
num_images, num_patches, channels, height, width = pixel_values.shape
|
||||
@ -299,9 +208,9 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
)
|
||||
image_features = self.vision_tower(pixel_values)
|
||||
|
||||
selected_image_feature = image_features.hidden_states[
|
||||
self.config.vision_feature_layer
|
||||
]
|
||||
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
|
||||
# Already done within the clip model
|
||||
selected_image_feature = image_features.last_hidden_state
|
||||
|
||||
if self.config.vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
@ -368,26 +277,21 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
new_image_features.append(image_feature)
|
||||
image_features = torch.stack(new_image_features, dim=0)
|
||||
|
||||
inputs_embeds, attention_mask, labels, position_ids = (
|
||||
self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
)
|
||||
if labels is None:
|
||||
labels = torch.full_like(attention_mask, self.config.ignore_index).to(
|
||||
torch.long
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_features
|
||||
)
|
||||
|
||||
logits = self.language_model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
lm_head_indices,
|
||||
hidden_states = self.language_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
)
|
||||
return logits
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.language_model.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
||||
|
@ -7,7 +7,6 @@ import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
|
||||
from transformers.models.llama import LlamaTokenizerFast
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
|
@ -52,7 +52,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
||||
if chunk["type"] == "text":
|
||||
full_text += chunk["content"]
|
||||
elif chunk["type"] == "image":
|
||||
full_text += "<image>"
|
||||
full_text += "<image>" * 2928
|
||||
images.append(chunk["content"])
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
|
||||
|
Loading…
Reference in New Issue
Block a user