from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
    _create_4d_causal_attention_mask,
    _prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
)
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig

from text_generation_server.layers import (
    TensorParallelEmbedding,
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)


class CLIPVisionEmbeddings(nn.Module):
    def __init__(self, prefix, config: CLIPVisionConfig, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # TODO Should we TP this ?
        self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype)
        )  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings


class CLIPTextEmbeddings(nn.Module):
    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, embed_dim
        )

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = (
            input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings
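# Illustrative sketch (not used by the model code above): how CLIPVisionEmbeddings turns
# pixels into a token sequence. The sizes below assume a ViT-B/32-style vision config
# (224x224 images, 32x32 patches, hidden size 768); adjust them to your own config.
def _sketch_patch_embedding_shapes():
    batch_size, channels, image_size, patch_size, embed_dim = 2, 3, 224, 32, 768
    conv = nn.Conv2d(
        channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False
    )
    pixel_values = torch.randn(batch_size, channels, image_size, image_size)
    # One embed_dim-sized vector per 32x32 patch: [2, 768, 7, 7].
    patch_embeds = conv(pixel_values)
    # Flatten the 7x7 grid into 49 patch tokens: [2, 49, 768].
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
    # Prepend one class token per image, as the real module does: [2, 50, 768].
    class_embeds = torch.randn(embed_dim).expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    assert embeddings.shape == (
        batch_size,
        (image_size // patch_size) ** 2 + 1,
        embed_dim,
    )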
class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = self.embed_dim // self.num_heads
        if self.head_size * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        # Attention heads (and therefore the local hidden dimension) are sharded
        # across the tensor-parallel process group.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.embed_dim = self.embed_dim // weights.process_group.size()

        self.scale = self.head_size**-0.5
        self.dropout = config.attention_dropout

        self.qkv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=True,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_size)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        qkv = self.qkv(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [
                self.head_size * self.num_heads,
            ]
            * 3,
            dim=2,
        )

        query_states = query_states * self.scale
        key_states = self._shape(key_states, -1, bsz)
        value_states = self._shape(value_states, -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_size)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + causal_attention_mask
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + attention_mask
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_probs = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None
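# Illustrative sketch (not used by the model code above): the shape flow of the fused
# QKV projection and the bmm-based attention in CLIPAttention, on plain tensors and
# without tensor parallelism. All dimensions are made up for the example.
def _sketch_fused_qkv_attention_shapes():
    bsz, seq_len, num_heads, head_size = 2, 50, 12, 64
    embed_dim = num_heads * head_size
    # Stand-in for `self.qkv(hidden_states)`: query, key and value concatenated on dim 2.
    qkv = torch.randn(bsz, seq_len, 3 * embed_dim)
    q, k, v = qkv.split([embed_dim] * 3, dim=2)

    def to_heads(t):
        # (bsz, seq, embed) -> (bsz * heads, seq, head_size) so torch.bmm batches over heads.
        return (
            t.view(bsz, seq_len, num_heads, head_size)
            .transpose(1, 2)
            .reshape(bsz * num_heads, seq_len, head_size)
        )

    q, k, v = to_heads(q) * head_size**-0.5, to_heads(k), to_heads(v)
    attn_weights = torch.bmm(q, k.transpose(1, 2))  # [bsz * heads, seq, seq]
    attn_probs = nn.functional.softmax(attn_weights, dim=-1)
    attn_output = torch.bmm(attn_probs, v)  # [bsz * heads, seq, head_size]
    attn_output = (
        attn_output.view(bsz, num_heads, seq_len, head_size)
        .transpose(1, 2)
        .reshape(bsz, seq_len, embed_dim)
    )
    assert attn_output.shape == (bsz, seq_len, embed_dim)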
class CLIPMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class CLIPEncoderLayer(nn.Module):
    def __init__(self, prefix, config: CLIPConfig, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
        )
        self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class CLIPPreTrainedModel(nn.Module):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CLIPConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True


CLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
"""
CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
"""

CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
"""


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, prefix, config: CLIPConfig, weights):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                CLIPEncoderLayer(
                    prefix=f"{prefix}.layers.{i}", config=config, weights=weights
                )
                for i in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
        """

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
            )

        return hidden_states
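# Illustrative sketch (not used by the model code above): how the two helpers imported at
# the top of this file produce the additive 4D masks that CLIPEncoder consumes. Shapes
# and token ids are made up.
def _sketch_attention_masks():
    bsz, seq_len = 2, 5
    dtype, device = torch.float32, torch.device("cpu")
    # Lower-triangular causal mask: 0 where attention is allowed, a large negative
    # value where it is not; shape [bsz, 1, seq_len, seq_len].
    causal_mask = _create_4d_causal_attention_mask((bsz, seq_len), dtype, device=device)
    # Padding mask: 1 for real tokens, 0 for padding, expanded to the same 4D layout
    # (also as additive large-negative values).
    attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
    expanded_mask = _prepare_4d_attention_mask(attention_mask, dtype)
    assert causal_mask.shape == (bsz, 1, seq_len, seq_len)
    assert expanded_mask.shape == (bsz, 1, seq_len, seq_len)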
class CLIPTextTransformer(nn.Module):
    def __init__(self, prefix, config: CLIPTextConfig, weights):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        r"""
        Returns:

        """
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _create_4d_causal_attention_mask(
            input_shape, hidden_states.dtype, device=hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(
                attention_mask, hidden_states.dtype
            )

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
        )

        # `CLIPEncoder` returns the final hidden states tensor directly.
        last_hidden_state = encoder_outputs
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
            pooled_output = last_hidden_state[
                torch.arange(
                    last_hidden_state.shape[0], device=last_hidden_state.device
                ),
                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
                    dim=-1
                ),
            ]
        else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
            pooled_output = last_hidden_state[
                torch.arange(
                    last_hidden_state.shape[0], device=last_hidden_state.device
                ),
                # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
                (
                    input_ids.to(dtype=torch.int, device=last_hidden_state.device)
                    == self.eos_token_id
                )
                .int()
                .argmax(dim=-1),
            ]

        return last_hidden_state
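# Illustrative sketch (not used by the model code above): the EOS-token pooling indexing
# used in CLIPTextTransformer.forward. The token ids are made up; 2 plays the role of
# `eos_token_id`, and trailing zeros stand in for padding.
def _sketch_eos_pooling():
    hidden = torch.randn(2, 5, 8)  # [batch, seq_len, width]
    input_ids = torch.tensor([[49406, 320, 1125, 2, 0], [49406, 320, 2, 0, 0]])
    eos_token_id = 2
    # First position where the token equals eos_token_id (argmax returns the first max).
    eos_positions = (input_ids == eos_token_id).int().argmax(dim=-1)  # tensor([3, 2])
    pooled = hidden[torch.arange(hidden.shape[0]), eos_positions]  # [batch, width]
    assert pooled.shape == (2, 8)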
class CLIPTextModel(CLIPPreTrainedModel):
    config_class = CLIPTextConfig

    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, prefix, config: CLIPTextConfig, weights):
        super().__init__()
        # Weights are read eagerly from the checkpoint via `weights`, so no further
        # initialization step is needed here.
        self.text_model = CLIPTextTransformer(prefix, config, weights)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPTextModel

        >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )


class CLIPVisionTransformer(nn.Module):
    def __init__(self, prefix, config: CLIPVisionConfig, weights):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        # `pre_layrnorm` (sic) matches the parameter name used in the upstream CLIP checkpoints.
        self.pre_layrnorm = nn.LayerNorm.load(
            prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
        )
        self.encoder = CLIPEncoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        # self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Returns:

        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
        )
        last_hidden_state = encoder_outputs
        # pooled_output = last_hidden_state[:, 0, :]
        # pooled_output = self.post_layernorm(pooled_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            # pooler_output=pooled_output,
            # hidden_states=encoder_outputs,
        )
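# Illustrative sketch (not used by the model code above): CLIPVisionTransformer returns the
# full hidden-state sequence (class token first, then one token per patch) and leaves
# `post_layernorm` and pooling commented out, so callers pick what they need. The sizes
# below assume a ViT-B/32-style vision config (224x224 images, 32x32 patches, hidden 768).
def _sketch_vision_output_layout():
    last_hidden_state = torch.randn(2, 50, 768)  # [batch, 1 + (224 // 32) ** 2, hidden]
    class_token = last_hidden_state[:, 0, :]  # what the commented-out pooler would use
    patch_tokens = last_hidden_state[:, 1:, :]  # per-patch features for downstream consumers
    assert class_token.shape == (2, 768) and patch_tokens.shape == (2, 49, 768)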
"pixel_values" _no_split_modules = ["CLIPEncoderLayer"] def __init__(self, config: CLIPVisionConfig): super().__init__(config) self.vision_model = CLIPVisionTransformer(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding def forward( self, pixel_values: Optional[torch.FloatTensor] = None, ): r""" Returns: Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, CLIPVisionModel >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled CLS states ```""" return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) return self.vision_model( pixel_values=pixel_values, ) class CLIPModel(nn.Module): def __init__(self, prefix, config: CLIPConfig, weights): super().__init__() text_config = config.text_config vision_config = config.vision_config self.projection_dim = config.projection_dim self.text_embed_dim = text_config.hidden_size self.vision_embed_dim = vision_config.hidden_size self.text_model = CLIPTextTransformer(text_config) self.vision_model = CLIPVisionTransformer(vision_config) self.visual_projection = nn.Linear( self.vision_embed_dim, self.projection_dim, bias=False ) self.text_projection = nn.Linear( self.text_embed_dim, self.projection_dim, bias=False ) self.logit_scale = nn.Parameter( torch.tensor(self.config.logit_scale_init_value) ) # Initialize weights and apply final processing self.post_init() def get_text_features( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: r""" Returns: text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`]. Examples: ```python >>> from transformers import AutoTokenizer, CLIPModel >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") >>> text_features = model.get_text_features(**inputs) ```""" text_outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, ) pooled_output = text_outputs[1] text_features = self.text_projection(pooled_output) return text_features def get_image_features( self, pixel_values: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: r""" Returns: image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`]. 
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
        )

        pooled_output = vision_outputs[1]  # pooled_output
        image_features = self.visual_projection(pooled_output)

        return image_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        return logits_per_image, logits_per_text
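# Illustrative sketch (not used by the model code above): the similarity computation at the
# end of CLIPModel.forward, on dummy projected embeddings. A `projection_dim` of 512 and
# the logit-scale value are just examples taken from reference CLIP configs.
def _sketch_clip_similarity():
    text_embeds = torch.randn(4, 512)  # [num_texts, projection_dim]
    image_embeds = torch.randn(2, 512)  # [num_images, projection_dim]
    # L2-normalize so the dot product below is a cosine similarity.
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    logit_scale = torch.tensor(2.6592).exp()  # exp of `logit_scale_init_value`
    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale  # [4, 2]
    logits_per_image = logits_per_text.t()  # [2, 4]
    probs = logits_per_image.softmax(dim=1)  # per-image distribution over candidate texts
    assert probs.shape == (2, 4)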