Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

Commit df4c700828 (parent 5f4b395480)
Commit message: Tmp dump (running on images hardcoded size.)
@@ -8,6 +8,11 @@ from transformers.modeling_attn_mask_utils import (
     _create_4d_causal_attention_mask,
     _prepare_4d_attention_mask,
 )
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPooling,
+    ImageClassifierOutput,
+)
 from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig

 from text_generation_server.utils.layers import (
@@ -147,7 +152,6 @@ class CLIPAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         causal_attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""

@@ -267,7 +271,6 @@ class CLIPEncoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         causal_attention_mask: torch.Tensor,
-        output_attentions: Optional[bool] = False,
     ):
         """
         Args:
@@ -275,9 +278,6 @@ class CLIPEncoderLayer(nn.Module):
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                 `(config.encoder_attention_heads,)`.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
         """
         residual = hidden_states

@@ -286,7 +286,6 @@ class CLIPEncoderLayer(nn.Module):
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             causal_attention_mask=causal_attention_mask,
-            output_attentions=output_attentions,
         )
         hidden_states = residual + hidden_states

@@ -346,14 +345,6 @@ CLIP_TEXT_INPUTS_DOCSTRING = r"""
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """

 CLIP_VISION_INPUTS_DOCSTRING = r"""
@@ -361,14 +352,6 @@ CLIP_VISION_INPUTS_DOCSTRING = r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """

 CLIP_INPUTS_DOCSTRING = r"""
@@ -398,14 +381,6 @@ CLIP_INPUTS_DOCSTRING = r"""
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """


@@ -435,9 +410,6 @@ class CLIPEncoder(nn.Module):
         inputs_embeds,
         attention_mask: Optional[torch.Tensor] = None,
         causal_attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
         r"""
         Args:
@@ -459,43 +431,16 @@ class CLIPEncoder(nn.Module):
                 - 0 for tokens that are **masked**.

                 [What are attention masks?](../glossary#attention-mask)
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
-        encoder_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
-
         hidden_states = inputs_embeds
         for idx, encoder_layer in enumerate(self.layers):
-            layer_outputs = encoder_layer(
+            hidden_states = encoder_layer(
                 hidden_states,
                 attention_mask,
                 causal_attention_mask,
-                output_attentions=output_attentions,
             )

-            hidden_states = layer_outputs[0]
-
         return hidden_states

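Note: after this hunk the CLIPEncoder forward path is a plain loop. Each CLIPEncoderLayer is assumed to hand back the hidden-state tensor itself rather than a (hidden_states, attentions) tuple, and no encoder_states/all_attentions accumulators are kept. A minimal toy sketch of that control flow (the nn.Linear layers are illustrative stand-ins, not the real encoder layers):

import torch
from torch import nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
hidden_states = torch.randn(2, 4, 8)  # (batch, seq, hidden)
for layer in layers:
    hidden_states = layer(hidden_states)  # each layer returns the tensor directly
print(hidden_states.shape)  # torch.Size([2, 4, 8])
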
@@ -518,28 +463,11 @@ class CLIPTextTransformer(nn.Module):
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
         r"""
         Returns:

         """
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
         if input_ids is None:
             raise ValueError("You have to specify input_ids")

@@ -564,9 +492,6 @@ class CLIPTextTransformer(nn.Module):
             inputs_embeds=hidden_states,
             attention_mask=attention_mask,
             causal_attention_mask=causal_attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )

         last_hidden_state = encoder_outputs[0]
@@ -621,9 +546,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
         r"""
         Returns:
@@ -650,9 +572,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )


@@ -671,37 +590,16 @@ class CLIPVisionTransformer(nn.Module):
         self.encoder = CLIPEncoder(
             prefix=f"{prefix}.encoder", config=config, weights=weights
         )
-        self.post_layernorm = nn.LayerNorm.load(
-            prefix=f"{prefix}.post_layernorm",
-            weights=weights,
-            eps=config.layer_norm_eps,
-        )
+        # self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps)

     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
         r"""
         Returns:

         """
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")

@@ -710,23 +608,15 @@ class CLIPVisionTransformer(nn.Module):

         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )
-
-        last_hidden_state = encoder_outputs[0]
-        pooled_output = last_hidden_state[:, 0, :]
-        pooled_output = self.post_layernorm(pooled_output)
-
-        if not return_dict:
-            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+        last_hidden_state = encoder_outputs
+        # pooled_output = last_hidden_state[:, 0, :]
+        # pooled_output = self.post_layernorm(pooled_output)

         return BaseModelOutputWithPooling(
             last_hidden_state=last_hidden_state,
-            pooler_output=pooled_output,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
+            # pooler_output=pooled_output,
+            # hidden_states=encoder_outputs,
         )

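Note: with pooling and post_layernorm commented out, the vision transformer now returns a BaseModelOutputWithPooling that only carries last_hidden_state, so callers must not rely on pooler_output, hidden_states or attentions. A small sketch of that contract (the 1x577x1024 shape is illustrative only):

import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling

# 577 would be 24*24 patches + CLS for a 336px CLIP-L tower; shapes are assumptions.
out = BaseModelOutputWithPooling(last_hidden_state=torch.randn(1, 577, 1024))
assert out.pooler_output is None and out.hidden_states is None
features = out.last_hidden_state  # the only field downstream code should read
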
@@ -747,9 +637,6 @@ class CLIPVisionModel(CLIPPreTrainedModel):
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
         r"""
         Returns:
@@ -779,9 +666,6 @@ class CLIPVisionModel(CLIPPreTrainedModel):

         return self.vision_model(
             pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )


@@ -816,9 +700,6 @@ class CLIPModel(nn.Module):
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ) -> torch.FloatTensor:
         r"""
         Returns:
@@ -836,28 +717,10 @@ class CLIPModel(nn.Module):
         >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
         >>> text_features = model.get_text_features(**inputs)
         ```"""
-        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
         text_outputs = self.text_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )

         pooled_output = text_outputs[1]
@@ -868,9 +731,6 @@ class CLIPModel(nn.Module):
     def get_image_features(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ) -> torch.FloatTensor:
         r"""
         Returns:
@@ -895,25 +755,8 @@ class CLIPModel(nn.Module):
         >>> image_features = model.get_image_features(**inputs)
         ```"""
         # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
         vision_outputs = self.vision_model(
             pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )

         pooled_output = vision_outputs[1]  # pooled_output
@@ -927,10 +770,6 @@ class CLIPModel(nn.Module):
         pixel_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        return_loss: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
         r"""
         Returns:
@@ -957,24 +796,8 @@ class CLIPModel(nn.Module):
         >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
         ```"""
         # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
         vision_outputs = self.vision_model(
             pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )

@@ -982,8 +805,6 @@ class CLIPModel(nn.Module):
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )

@@ -376,7 +376,7 @@ class FlashLlamaModel(torch.nn.Module):

     def forward(
         self,
-        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -385,7 +385,7 @@ class FlashLlamaModel(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        hidden_states = inputs_embeds

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -437,8 +437,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
-            input_ids,
+            inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
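Note: FlashLlamaModel.forward now consumes inputs_embeds, and FlashLlamaForCausalLM embeds input_ids itself before calling it, so a multimodal wrapper can pass pre-built (image-augmented) embeddings straight to the transformer stack. A hedged toy sketch of that calling convention (ToyModel/ToyCausalLM are illustrative, not the real flash-attention modules):

import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.layer = nn.Linear(hidden, hidden)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        return self.layer(inputs_embeds)  # consumes embeddings, never token ids

class ToyCausalLM(nn.Module):
    def __init__(self, vocab: int = 32, hidden: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.model = ToyModel(hidden)
        self.lm_head = nn.Linear(hidden, vocab)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)  # mirrors the new wrapper behaviour
        hidden_states = self.model(inputs_embeds)
        return self.lm_head(hidden_states)

print(ToyCausalLM()(torch.randint(0, 32, (1, 5))).shape)  # torch.Size([1, 5, 32])
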
@@ -22,7 +22,11 @@ from torch import nn

 from transformers.activations import ACT2FN
 from transformers.image_processing_utils import select_best_resolution
-from transformers import AutoModel, AutoModelForCausalLM
+
+from text_generation_server.utils.layers import (
+    TensorParallelColumnLinear,
+    TensorParallelRowLinear,
+)


 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
@@ -83,15 +87,15 @@ def unpad_image(tensor, original_size):

 # Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
 class LlavaNextMultiModalProjector(nn.Module):
-    def __init__(self, config):
+    def __init__(self, prefix, config, weights):
         super().__init__()

-        self.linear_1 = nn.Linear(
-            config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
+        self.linear_1 = TensorParallelColumnLinear.load(
+            prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
         )
         self.act = ACT2FN[config.projector_hidden_act]
-        self.linear_2 = nn.Linear(
-            config.text_config.hidden_size, config.text_config.hidden_size, bias=True
+        self.linear_2 = TensorParallelRowLinear.load(
+            prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
         )

     def forward(self, image_features):
@@ -135,13 +139,19 @@ class LlavaNextForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
+        vision_config = config.vision_config
+        # Instead of selecting in hidden_states[-2].
+        # Instead compute only the n -2 + 1 layers and don't pool
+        vision_config.num_hidden_layers += config.vision_feature_layer + 1
         self.vision_tower = load_vision_model(
             prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
             weights=weights,
         )

-        self.multi_modal_projector = LlavaNextMultiModalProjector(config)
+        self.multi_modal_projector = LlavaNextMultiModalProjector(
+            prefix="multi_modal_projector", config=config, weights=weights
+        )

         self.image_newline = weights.get_tensor("image_newline")

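Note: the `vision_config.num_hidden_layers += config.vision_feature_layer + 1` line shrinks the CLIP tower so that its last hidden state is exactly the feature layer llava-next normally selects, which is why the vision model above no longer needs output_hidden_states or pooling. Illustrative arithmetic, assuming the usual vision_feature_layer = -2 and a 24-layer tower (both are assumptions, not stated in the diff):

num_hidden_layers = 24
vision_feature_layer = -2
num_hidden_layers += vision_feature_layer + 1  # 24 + (-2) + 1 = 23
# With only 23 layers instantiated, the encoder's last_hidden_state is what
# hidden_states[-2] of the full tower would have been, so no per-layer
# hidden_states collection (and no pooled CLS output) is needed.
print(num_hidden_layers)  # 23
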
@@ -158,114 +168,17 @@ class LlavaNextForConditionalGeneration(nn.Module):
             config.pad_token_id if config.pad_token_id is not None else -1
         )

-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features
     def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds, input_ids, attention_mask, labels
+        self,
+        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
+        image_features: torch.Tensor,
     ):
-        num_images, num_image_patches, embed_dim = image_features.shape
-        batch_size, sequence_length = input_ids.shape
-        left_padding = not torch.sum(
-            input_ids[:, -1] == torch.tensor(self.pad_token_id)
-        )
-        # 1. Create a mask to know where special image tokens are
-        special_image_token_mask = input_ids == self.config.image_token_index
-        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
-        # Compute the maximum embed dimension
-        max_embed_dim = (
-            num_special_image_tokens.max() * (num_image_patches - 1)
-        ) + sequence_length
-        batch_indices, non_image_indices = torch.where(
-            input_ids != self.config.image_token_index
-        )
-
-        # 2. Compute the positions where text should be written
-        # Calculate new positions for text tokens in merged image-text sequence.
-        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
-        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
-        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
-        new_token_positions = (
-            torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1)
-            - 1
-        )
-        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
-        if left_padding:
-            new_token_positions += nb_image_pad[:, None]  # offset for left padding
-        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
-        # 3. Create the full embedding, already padded to the maximum position
-        final_embedding = torch.zeros(
-            batch_size,
-            max_embed_dim,
-            embed_dim,
-            dtype=inputs_embeds.dtype,
-            device=inputs_embeds.device,
-        )
-        final_attention_mask = torch.zeros(
-            batch_size,
-            max_embed_dim,
-            dtype=attention_mask.dtype,
-            device=inputs_embeds.device,
-        )
-        if labels is not None:
-            final_labels = torch.full(
-                (batch_size, max_embed_dim),
-                self.config.ignore_index,
-                dtype=input_ids.dtype,
-                device=input_ids.device,
-            )
-        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
-        # set the corresponding tensors into their correct target device.
-        target_device = inputs_embeds.device
-        batch_indices, non_image_indices, text_to_overwrite = (
-            batch_indices.to(target_device),
-            non_image_indices.to(target_device),
-            text_to_overwrite.to(target_device),
-        )
-        attention_mask = attention_mask.to(target_device)
-
-        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
-        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
-            batch_indices, non_image_indices
-        ]
-        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
-            batch_indices, non_image_indices
-        ]
-        if labels is not None:
-            final_labels[batch_indices, text_to_overwrite] = labels[
-                batch_indices, non_image_indices
-            ]
-
-        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
-        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
-        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[
-            :, None
-        ].to(target_device)
-
-        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
-            raise ValueError(
-                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
-                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
-            )
-
-        final_embedding[image_to_overwrite] = (
-            image_features.contiguous().reshape(-1, embed_dim).to(target_device)
-        )
-        final_attention_mask |= image_to_overwrite
-        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
-            (final_attention_mask == 0), 1
-        )
-
-        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
-        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
-        indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
-        final_embedding[batch_indices, indices_to_mask] = 0
-
-        if labels is None:
-            final_labels = None
-
-        return final_embedding, final_attention_mask, final_labels, position_ids
+        """In place merges in vision_embeddings with inputs_embeds."""
+        mask = input_ids == self.config.image_token_index
+        # Let's pray we have enabled enough slots !
+        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        return inputs_embeds

     def forward(
         self,
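Note: the rewritten _merge_input_ids_with_image_features assumes the prompt already reserves one image-token id per image feature (see the `"<image>" * 2928` change further down), so the merge collapses to a boolean-mask scatter into the token embeddings instead of rebuilding padded sequences. A self-contained sketch with illustrative token ids and shapes:

import torch

image_token_index = 32000  # illustrative id, not read from a real config
input_ids = torch.tensor([[1, 32000, 32000, 32000, 15043]])  # 3 reserved image slots
inputs_embeds = torch.randn(1, 5, 8)
image_features = torch.randn(1, 3, 8)  # one embedding per reserved slot

mask = input_ids == image_token_index
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
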
@@ -282,15 +195,11 @@ class LlavaNextForConditionalGeneration(nn.Module):
         pixel_values: torch.FloatTensor = None,
         image_sizes: Optional[torch.LongTensor] = None,
     ):
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
         if pixel_values is not None and len(pixel_values) > 0:
-            num_special_image_tokens = (
-                input_ids == self.config.image_token_index
-            ).sum()
-            assert num_special_image_tokens == len(
-                pixel_values
-            ), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+            # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
+            # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
             # 1. Extract the input embeddings
-            inputs_embeds = self.language_model.model.embed_tokens(input_ids)

             # 2. Merge text and images
             num_images, num_patches, channels, height, width = pixel_values.shape
@@ -299,9 +208,9 @@ class LlavaNextForConditionalGeneration(nn.Module):
             )
             image_features = self.vision_tower(pixel_values)

-            selected_image_feature = image_features.hidden_states[
-                self.config.vision_feature_layer
-            ]
+            # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
+            # Already done within the clip model
+            selected_image_feature = image_features.last_hidden_state

             if self.config.vision_feature_select_strategy == "default":
                 selected_image_feature = selected_image_feature[:, 1:]
@@ -368,26 +277,21 @@ class LlavaNextForConditionalGeneration(nn.Module):
                 new_image_features.append(image_feature)
             image_features = torch.stack(new_image_features, dim=0)

-            inputs_embeds, attention_mask, labels, position_ids = (
-                self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, labels
-                )
+            inputs_embeds = self._merge_input_ids_with_image_features(
+                input_ids, inputs_embeds, image_features
             )
-            if labels is None:
-                labels = torch.full_like(attention_mask, self.config.ignore_index).to(
-                    torch.long
-                )

-        logits = self.language_model(
-            input_ids,
-            position_ids,
-            cu_seqlen_prefill,
-            kv_cache,
-            block_tables,
-            slots,
-            input_lengths,
-            max_s,
-            prefill_cache_indices,
-            lm_head_indices,
+        hidden_states = self.language_model.model(
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
         )
-        return logits
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.language_model.lm_head(hidden_states)
+        return logits, speculative_logits
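Note: the wrapper now drives the text model through language_model.model with keyword arguments and applies lm_head itself, slicing with lm_head_indices first so logits are only computed for the positions that will actually be sampled. A toy sketch of that slicing (a plain nn.Linear stands in for the real head, which also returns speculative logits):

import torch
from torch import nn

lm_head = nn.Linear(16, 32)              # stand-in for language_model.lm_head
hidden_states = torch.randn(10, 16)      # flattened tokens from a prefill batch
lm_head_indices = torch.tensor([3, 9])   # last position of each sequence

hidden_states = hidden_states[lm_head_indices]
logits = lm_head(hidden_states)          # shape (2, 32) instead of (10, 32)
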
@@ -7,7 +7,6 @@ import numpy as np
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
-from transformers.models.llama import LlamaTokenizerFast
 from typing import Optional, Tuple, Type

 from text_generation_server.pb import generate_pb2
@@ -52,7 +52,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
             if chunk["type"] == "text":
                 full_text += chunk["content"]
             elif chunk["type"] == "image":
-                full_text += "<image>"
+                full_text += "<image>" * 2928
                 images.append(chunk["content"])
             else:
                 raise RuntimeError(f"Invalid chunk type {chunk['type']}")
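Note: the hardcoded 2928 is the "hardcoded size" from the commit title: every image is expanded to a fixed number of <image> placeholder tokens so the in-place merge above gets exactly one token slot per vision feature. Where 2928 comes from is an inference, not stated in the diff; it is consistent with the default llava-next geometry, as this back-of-the-envelope shows:

# Assumed defaults: 336px CLIP tiles, 14px patches, a 2x2 grid of high-resolution tiles.
patches_per_side = 336 // 14              # 24
base_tile = patches_per_side**2           # 576 features from the resized full image
grid_side = 2 * patches_per_side          # 48x48 feature grid from the 2x2 tiles
high_res = grid_side * (grid_side + 1)    # plus one image_newline feature per row -> 2352
print(base_tile + high_res)               # 2928
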