mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
# What does this PR do? - Changed all models to extract `embed_tokens` so that llava can call the embeddings and the core model layers separately. - Added VlmCausalLM, which inherits from FlashMistral so that it is as broadly supported as possible. The only added logic sits on top: it parses images into pixel values, preallocates space in input_ids for the image embeddings, and passes them to the model. - Added CLIP for the vision tower. - Did not add flash attention for the vision tower, since there is no padding anyway. - Added a (potentially incomplete) heuristic to calculate the number of image features *before* computing the CLIP patches (this makes it easier to reuse the LLM logic under the hood). Still needs to be done: - [x] Implement the image parsing on the controller side, to avoid downloading the image once per TP shard, to reject oversized requests early, and to avoid truncation cutting into the image. - [ ] Make sure it works with quantization properly. - [x] Make sure it works with TP>1 <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). 
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
828 lines
31 KiB
Python
828 lines
31 KiB
Python
from typing import Optional, Tuple, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from transformers.activations import ACT2FN
|
|
from transformers.modeling_attn_mask_utils import (
|
|
_create_4d_causal_attention_mask,
|
|
_prepare_4d_attention_mask,
|
|
)
|
|
from transformers.modeling_outputs import (
|
|
BaseModelOutput,
|
|
BaseModelOutputWithPooling,
|
|
ImageClassifierOutput,
|
|
)
|
|
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
|
|
|
from text_generation_server.utils.layers import (
|
|
TensorParallelEmbedding,
|
|
TensorParallelColumnLinear,
|
|
TensorParallelRowLinear,
|
|
)
|
|
|
|
|
|
class CLIPVisionEmbeddings(nn.Module):
    """Turn pixel values into CLIP vision tokens.

    The output is a class token followed by one token per image patch, each summed with a
    learned position embedding. Tensors are read from the checkpoint under ``prefix``:
    ``class_embedding``, ``patch_embedding.weight`` and ``position_embedding``.
    """

    def __init__(self, prefix, config: CLIPVisionConfig, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # TODO Should we TP this ?
        self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")

        # kernel == stride == patch_size: the convolution cuts the image into non-overlapping patches.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        conv_weight = weights.get_tensor(f"{prefix}.patch_embedding.weight")
        self.patch_embedding.weight = nn.Parameter(conv_weight, requires_grad=False)

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # One extra position for the class token prepended in forward().
        self.num_positions = self.num_patches + 1
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )
        all_positions = torch.arange(self.num_positions, device=weights.device)
        self.register_buffer(
            "position_ids", all_positions.expand((1, -1)), persistent=False
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        # Presumably pixel_values is (batch, channels, height, width) — confirm with the caller.
        conv_dtype = self.patch_embedding.weight.dtype
        patch_grid = self.patch_embedding(pixel_values.to(dtype=conv_dtype))
        # (batch, embed_dim, grid, grid) -> (batch, grid * grid, embed_dim)
        patch_tokens = patch_grid.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(pixel_values.shape[0], 1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)
|
|
|
|
|
|
class CLIPTextEmbeddings(nn.Module):
    """Sum of token embeddings and learned absolute position embeddings for CLIP text.

    Both tables are freshly constructed ``nn.Embedding`` modules; this class does not load
    anything from a checkpoint itself.
    """

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        width = config.hidden_size
        max_positions = config.max_position_embeddings

        self.token_embedding = nn.Embedding(config.vocab_size, width)
        self.position_embedding = nn.Embedding(max_positions, width)

        # Default (1, max_position_embeddings) position ids; non-persistent, so not saved in state dicts.
        default_ids = torch.arange(max_positions).expand((1, -1))
        self.register_buffer("position_ids", default_ids, persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` (or use ``inputs_embeds`` as given) and add position embeddings.

        When ``position_ids`` is omitted, positions ``0 .. seq_len - 1`` are used. ``seq_len`` is
        taken from the last dim of ``input_ids``, or else the second-to-last dim of ``inputs_embeds``.
        """
        if input_ids is not None:
            seq_len = input_ids.shape[-1]
        else:
            seq_len = inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_len]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        return inputs_embeds + self.position_embedding(position_ids)
|
|
|
|
|
|
class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, sharded for tensor parallelism.

    Q/K/V are loaded as one fused column-parallel projection and ``out_proj`` is row-parallel.
    Each rank therefore computes ``num_attention_heads // world_size`` heads over a
    ``hidden_size // world_size`` slice.

    Args:
        prefix: checkpoint prefix holding ``q_proj``, ``k_proj``, ``v_proj`` and ``out_proj``.
        config: reads ``hidden_size``, ``num_attention_heads`` and ``attention_dropout``.
        weights: weight loader; ``weights.process_group.size()`` is the tensor-parallel world size.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``, or if the head
            count is not divisible by the tensor-parallel world size.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = self.embed_dim // self.num_heads
        if self.head_size * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        world_size = weights.process_group.size()
        # Every rank must own a whole number of heads; otherwise the floor division below would
        # silently drop heads.
        if self.num_heads % world_size != 0:
            raise ValueError(
                f"`num_heads` must be divisible by the number of shards (got `num_heads`: {self.num_heads}"
                f" and `num_shards`: {world_size})."
            )
        self.num_heads = self.num_heads // world_size
        self.embed_dim = self.embed_dim // world_size
        # The scaling uses the per-head size, which sharding does not change.
        self.scale = self.head_size**-0.5
        self.dropout = config.attention_dropout

        self.qkv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=True,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """(bsz, seq_len, num_heads * head_size) -> contiguous (bsz, num_heads, seq_len, head_size)."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_size)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Both masks, when given, must have shape ``(bsz, 1, tgt_len, src_len)``. They are *added* to
        the attention scores before the softmax, so masked positions should hold large negative values.

        Returns:
            ``(attn_output, None)``. Attention weights are never returned.

        Raises:
            ValueError: on unexpected attention-weight, mask or output shapes.
        """
        bsz, tgt_len, _ = hidden_states.size()

        # Fused projection, then split into query / key / value slices for this rank's heads.
        qkv = self.qkv(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [self.head_size * self.num_heads] * 3,
            dim=2,
        )
        query_states = query_states * self.scale
        key_states = self._shape(key_states, -1, bsz)
        value_states = self._shape(value_states, -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_size)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + causal_attention_mask
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + attention_mask
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Dropout is only active when the module is in training mode.
        attn_probs = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back into this rank's embed_dim slice; out_proj (row-parallel) combines ranks.
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None
|
|
|
|
|
|
class CLIPMLP(nn.Module):
    """Feed-forward sub-block computing ``fc2(activation(fc1(x)))``.

    ``fc1`` is loaded column-parallel and ``fc2`` row-parallel, both with bias; the activation is
    looked up from ``config.hidden_act``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)
|
|
|
|
|
|
class CLIPEncoderLayer(nn.Module):
    """Pre-LayerNorm transformer block used by both CLIP towers.

    Computes ``h = x + self_attn(layer_norm1(x))`` followed by ``h + mlp(layer_norm2(h))``.

    Args:
        prefix: checkpoint prefix of this layer (``self_attn``, ``mlp``, ``layer_norm1``, ``layer_norm2``).
        config: text or vision sub-config; reads ``hidden_size`` and ``layer_norm_eps``.
        weights: weight loader forwarded to the sub-modules.
    """

    def __init__(
        self, prefix, config: "Union[CLIPTextConfig, CLIPVisionConfig]", weights
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        # NOTE(review): `nn.LayerNorm.load` is not a stock torch API; presumably it is patched in by
        # text_generation_server.utils.layers (imported by this file) — verify.
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
        )
        self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor` or `None`): additive mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor` or `None`): additive causal mask with the same shape
                convention, applied before `attention_mask`.

        Returns:
            The layer output, with the same shape as `hidden_states`.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        # The attention module also returns attention weights (always None); they are not needed here.
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
|
|
|
|
|
|
class CLIPPreTrainedModel(nn.Module):
    """
    Common base class for the CLIP wrapper models in this module.

    This is a plain ``nn.Module``: it only carries class attributes named like those of
    ``transformers.PreTrainedModel`` (``config_class``, ``base_model_prefix``,
    ``supports_gradient_checkpointing``). It does not initialise, download or load any weights,
    and it adds no constructor or ``post_init`` hook of its own.
    """

    config_class = CLIPConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True
|
|
|
|
|
|
# Docstring templates carried over from the `transformers` CLIP implementation.
# NOTE(review): none of the classes in this module appear to reference them; they are kept as
# plain module-level strings for parity with upstream — confirm before removing.
CLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
"""

CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
"""

CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
"""
|
|
|
|
|
|
class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        prefix: checkpoint prefix; layer ``i`` is loaded from ``{prefix}.layers.{i}``.
        config: text or vision sub-config (reads ``num_hidden_layers`` plus whatever
            ``CLIPEncoderLayer`` reads).
        weights: weight loader forwarded to every layer.
    """

    def __init__(
        self, prefix, config: "Union[CLIPTextConfig, CLIPVisionConfig]", weights
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                CLIPEncoderLayer(
                    prefix=f"{prefix}.layers.{i}", config=config, weights=weights
                )
                for i in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
    ):
        r"""
        Run every layer in order.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Already-embedded input sequence.
            attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
                Additive padding mask forwarded to every layer. It is added to the attention scores, so masked
                positions must hold large negative values (not a `[0, 1]` mask).
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
                Additive causal mask (used by the text tower), forwarded to every layer.

        Returns:
            The hidden states after the last layer, as a plain tensor (not a ModelOutput).
        """
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
            )

        return hidden_states
|
|
|
|
|
|
class CLIPTextTransformer(nn.Module):
    """CLIP text tower: token + position embeddings, a causally masked CLIPEncoder and a final LayerNorm.

    Args:
        config: CLIP text config (reads ``hidden_size``, ``layer_norm_eps``, ``eos_token_id`` and whatever
            ``CLIPTextEmbeddings`` / ``CLIPEncoder`` read).
        prefix: checkpoint prefix of the text tower. The default ``"text_model"`` is an assumption
            based on `transformers` naming — verify against the checkpoint.
        weights: weight loader providing ``get_tensor`` and ``device``. Required.

    Raises:
        ValueError: if ``weights`` is not given.
    """

    def __init__(
        self, config: CLIPTextConfig, prefix: str = "text_model", weights=None
    ):
        super().__init__()
        if weights is None:
            raise ValueError(
                "CLIPTextTransformer requires `weights` to load its embeddings, encoder and final layer norm"
            )
        self.config = config
        self.embeddings = CLIPTextEmbeddings(config)
        # CLIPTextEmbeddings builds fresh nn.Embedding tables; replace them with checkpoint tensors,
        # the same way CLIPVisionEmbeddings replaces its patch-embedding weight.
        # NOTE(review): key names assume the `transformers` CLIP layout — verify.
        self.embeddings.token_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.embeddings.token_embedding.weight"),
            requires_grad=False,
        )
        self.embeddings.position_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.embeddings.position_embedding.weight"),
            requires_grad=False,
        )
        # The default position_ids buffer is created on the default device; keep it next to the weights.
        self.embeddings.position_ids = self.embeddings.position_ids.to(weights.device)
        self.encoder = CLIPEncoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        self.final_layer_norm = nn.LayerNorm.load(
            prefix=f"{prefix}.final_layer_norm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        r"""
        Encode token ids with a causal mask.

        Args:
            input_ids: token ids; all leading dims are flattened into the batch.
            attention_mask: optional `[0, 1]` padding mask of shape `(batch, seq_len)`, converted to an
                additive 4-D mask here.
            position_ids: optional explicit positions; defaults to `0 .. seq_len - 1`.

        Returns:
            `BaseModelOutputWithPooling` whose `last_hidden_state` is the layer-normed encoder output and
            whose `pooler_output` is the hidden state at each sequence's EOS token.

        Raises:
            ValueError: if `input_ids` is None.
        """
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _create_4d_causal_attention_mask(
            input_shape, hidden_states.dtype, device=hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(
                attention_mask, hidden_states.dtype
            )

        # CLIPEncoder returns the final hidden states tensor itself; indexing it with [0] would
        # select the first batch element.
        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
        )
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        batch_index = torch.arange(
            last_hidden_state.shape[0], device=last_hidden_state.device
        )
        token_ids = input_ids.to(dtype=torch.int, device=last_hidden_state.device)
        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
            eos_positions = token_ids.argmax(dim=-1)
        else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
            # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
            eos_positions = (token_ids == self.eos_token_id).int().argmax(dim=-1)
        pooled_output = last_hidden_state[batch_index, eos_positions]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
        )
|
|
|
|
|
|
class CLIPTextModel(CLIPPreTrainedModel):
    """Wrapper exposing a ``CLIPTextTransformer`` as ``self.text_model``."""

    config_class = CLIPTextConfig

    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        # CLIPPreTrainedModel derives from nn.Module, whose __init__ accepts no config and which has
        # no post_init(); keep the config on the instance instead.
        super().__init__()
        self.config = config
        # NOTE(review): only the config is passed, so no checkpoint weights reach the text tower from
        # this wrapper — TODO wire prefix/weights through once the tower accepts them.
        self.text_model = CLIPTextTransformer(config)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        r"""
        Run the text tower on the given inputs.

        Args:
            input_ids: token ids.
            attention_mask: optional padding mask.
            position_ids: optional explicit positions.

        Returns:
            Whatever ``self.text_model`` returns for these inputs.
        """
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
|
|
|
|
|
|
class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower: patch embeddings, then ``pre_layrnorm``, then the CLIPEncoder stack.

    Only the final hidden states are produced. No ``post_layernorm`` is loaded and no pooled
    output is computed.

    Args:
        prefix: checkpoint prefix of the vision tower (``embeddings``, ``pre_layrnorm``, ``encoder``).
        config: CLIP vision config.
        weights: weight loader forwarded to the sub-modules.
    """

    def __init__(self, prefix, config: CLIPVisionConfig, weights):
        super().__init__()
        self.config = config

        self.embeddings = CLIPVisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.pre_layrnorm = nn.LayerNorm.load(
            prefix=f"{prefix}.pre_layrnorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )
        self.encoder = CLIPEncoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encode images.

        Returns:
            ``BaseModelOutputWithPooling`` with only ``last_hidden_state`` set.

        Raises:
            ValueError: if ``pixel_values`` is None.
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        normed_tokens = self.pre_layrnorm(self.embeddings(pixel_values))
        final_states = self.encoder(inputs_embeds=normed_tokens)
        return BaseModelOutputWithPooling(last_hidden_state=final_states)
|
|
|
|
|
|
class CLIPVisionModel(CLIPPreTrainedModel):
    """Wrapper exposing a ``CLIPVisionTransformer`` as ``self.vision_model``.

    Args:
        config: CLIP vision config.
        prefix: checkpoint prefix of the vision tower. The default ``"vision_model"`` is an assumption
            based on `transformers` naming — verify against the checkpoint.
        weights: weight loader. Required, because ``CLIPVisionTransformer`` loads its tensors at
            construction.

    Raises:
        ValueError: if ``weights`` is not given.
    """

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(
        self, config: CLIPVisionConfig, prefix: str = "vision_model", weights=None
    ):
        # CLIPPreTrainedModel derives from nn.Module, whose __init__ accepts no config and which has
        # no post_init(); keep the config on the instance instead.
        super().__init__()
        if weights is None:
            raise ValueError("CLIPVisionModel requires `weights` to load the vision tower")
        self.config = config
        self.vision_model = CLIPVisionTransformer(
            prefix=prefix, config=config, weights=weights
        )

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding convolution of the vision tower."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Run the vision tower.

        Returns:
            Whatever ``self.vision_model`` returns for ``pixel_values``.
        """
        return self.vision_model(
            pixel_values=pixel_values,
        )
|
|
|
|
|
|
class CLIPModel(nn.Module):
    """Dual-encoder CLIP: text and vision towers plus linear projections into a shared space.

    NOTE(review): ``visual_projection`` / ``text_projection`` are freshly initialised ``nn.Linear``
    modules, and the text tower is built from its config alone, so none of these are loaded
    from ``weights`` here — TODO.

    Args:
        prefix: checkpoint prefix. The vision tower is read from ``{prefix}.vision_model``, which
            assumes the `transformers` CLIPModel key layout — verify.
        config: composite CLIP config (``text_config``, ``vision_config``, ``projection_dim``,
            ``logit_scale_init_value``).
        weights: weight loader.
    """

    def __init__(self, prefix, config: CLIPConfig, weights):
        super().__init__()
        # Stored because it is read below (logit_scale) and expected by callers; nn.Module does not set it.
        self.config = config
        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPTextTransformer(text_config)
        self.vision_model = CLIPVisionTransformer(
            prefix=f"{prefix}.vision_model", config=vision_config, weights=weights
        )

        self.visual_projection = nn.Linear(
            self.vision_embed_dim, self.projection_dim, bias=False
        )
        self.text_projection = nn.Linear(
            self.text_embed_dim, self.projection_dim, bias=False
        )
        self.logit_scale = nn.Parameter(torch.tensor(config.logit_scale_init_value))

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): the text tower's pooled
            output passed through `text_projection`.
        """
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        # NOTE(review): index 1 must be the pooled (EOS-token) output; confirm the text tower returns one.
        pooled_output = text_outputs[1]
        return self.text_projection(pooled_output)

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): the vision tower's pooled
            output passed through `visual_projection`.
        """
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
        )
        # NOTE(review): index 1 must be a pooled output. CLIPVisionTransformer builds its
        # BaseModelOutputWithPooling with only `last_hidden_state`, in which case this index fails — TODO.
        pooled_output = vision_outputs[1]
        return self.visual_projection(pooled_output)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        r"""
        Score every text against every image.

        Returns:
            `(logits_per_image, logits_per_text)`: scaled cosine similarities of the L2-normalised
            projected embeddings; `logits_per_image` is the transpose of `logits_per_text`.
        """
        image_embeds = self.get_image_features(pixel_values=pixel_values)
        text_embeds = self.get_text_features(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        return logits_per_image, logits_per_text
|