| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  | # coding=utf-8 | 
					
						
							|  |  |  | # Copyright 2024 the HuggingFace Inc. team. All rights reserved. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #     http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							| 
									
										
										
										
											2025-09-02 13:35:42 +00:00
										 |  |  | """PyTorch Idefics2 model.""" | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-26 14:29:09 +00:00
										 |  |  | from typing import List, Optional, Tuple | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | import torch.utils.checkpoint | 
					
						
							|  |  |  | from torch import nn | 
					
						
							|  |  |  | import math | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from transformers.activations import ACT2FN | 
					
						
							|  |  |  | from text_generation_server.models.custom_modeling.vlm import ( | 
					
						
							|  |  |  |     load_text_model, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-08-29 14:29:01 +00:00
										 |  |  | from text_generation_server.layers.attention import Seqlen | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  | from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-13 10:44:30 +00:00
										 |  |  | from text_generation_server.layers import ( | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  |     TensorParallelColumnLinear, | 
					
						
							|  |  |  |     TensorParallelEmbedding, | 
					
						
							|  |  |  |     TensorParallelRowLinear, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
											  
											
												Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights
Handling of quantized weights was split between two mechanisms:
- For quantized checkpoints, we used the new weight loader
  infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
  instead relied on conditional in `get_linear`.
Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
`quantizer`. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.
This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:
- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers,
  `get_linear` does not need to know how to handle quantizer linear
  layers.
- All quantizer weights are strongly typed, we don't pass around
  raw tensors.
- We don't have to pass around the `quantizer` string everywhere.
* Exclude non-MLP layers when using FP8 quantization with Llama
											
										 
											2024-07-19 07:37:39 +00:00
										 |  |  | from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, | 
					
						
							|  |  |  |     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     batch, num_key_value_heads, slen, head_dim = hidden_states.shape | 
					
						
							|  |  |  |     if n_rep == 1: | 
					
						
							|  |  |  |         return hidden_states | 
					
						
							|  |  |  |     hidden_states = hidden_states[:, :, None, :, :].expand( | 
					
						
							|  |  |  |         batch, num_key_value_heads, n_rep, slen, head_dim | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2VisionEmbeddings(nn.Module): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable | 
					
						
							|  |  |  |     resolution. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) | 
					
						
							|  |  |  |     which allows treating images in their native aspect ratio and without the need to resize them to the same | 
					
						
							|  |  |  |     fixed size. In particular, we start from the original pre-trained SigLIP model | 
					
						
							|  |  |  |     (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.embed_dim = config.hidden_size | 
					
						
							|  |  |  |         self.image_size = config.image_size | 
					
						
							|  |  |  |         self.patch_size = config.patch_size | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.patch_embedding = nn.Conv2d( | 
					
						
							|  |  |  |             in_channels=config.num_channels, | 
					
						
							|  |  |  |             out_channels=self.embed_dim, | 
					
						
							|  |  |  |             kernel_size=self.patch_size, | 
					
						
							|  |  |  |             stride=self.patch_size, | 
					
						
							|  |  |  |             padding="valid", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.patch_embedding.weight = nn.Parameter( | 
					
						
							|  |  |  |             weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.patch_embedding.bias = nn.Parameter( | 
					
						
							|  |  |  |             weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.num_patches_per_side = self.image_size // self.patch_size | 
					
						
							|  |  |  |         self.num_patches = self.num_patches_per_side**2 | 
					
						
							|  |  |  |         self.num_positions = self.num_patches | 
					
						
							|  |  |  |         self.position_embedding = TensorParallelEmbedding( | 
					
						
							|  |  |  |             prefix=f"{prefix}.position_embedding", weights=weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward( | 
					
						
							|  |  |  |         self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor | 
					
						
							|  |  |  |     ) -> torch.Tensor: | 
					
						
							|  |  |  |         batch_size, _, max_im_h, max_im_w = pixel_values.shape | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         patch_embeds = self.patch_embedding(pixel_values) | 
					
						
							|  |  |  |         embeddings = patch_embeds.flatten(2).transpose(1, 2) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         max_nb_patches_h, max_nb_patches_w = ( | 
					
						
							|  |  |  |             max_im_h // self.patch_size, | 
					
						
							|  |  |  |             max_im_w // self.patch_size, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         boundaries = torch.arange( | 
					
						
							|  |  |  |             1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         position_ids = torch.full( | 
					
						
							|  |  |  |             size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for batch_idx, p_attn_mask in enumerate(patch_attention_mask): | 
					
						
							|  |  |  |             nb_patches_h = p_attn_mask[:, 0].sum() | 
					
						
							|  |  |  |             nb_patches_w = p_attn_mask[0].sum() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) | 
					
						
							|  |  |  |             fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             bucket_coords_h = torch.bucketize( | 
					
						
							|  |  |  |                 fractional_coords_h, boundaries, right=True | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             bucket_coords_w = torch.bucketize( | 
					
						
							|  |  |  |                 fractional_coords_w, boundaries, right=True | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             pos_ids = ( | 
					
						
							|  |  |  |                 bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w | 
					
						
							|  |  |  |             ).flatten() | 
					
						
							|  |  |  |             position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         position_ids = position_ids.to(self.position_embedding.weight.device) | 
					
						
							|  |  |  |         embeddings = embeddings + self.position_embedding(position_ids) | 
					
						
							|  |  |  |         return embeddings | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2VisionAttention(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.config = config | 
					
						
							|  |  |  |         self.embed_dim = config.hidden_size | 
					
						
							|  |  |  |         self.num_heads = config.num_attention_heads | 
					
						
							|  |  |  |         self.head_size = self.embed_dim // self.num_heads | 
					
						
							|  |  |  |         if self.head_size * self.num_heads != self.embed_dim: | 
					
						
							|  |  |  |             raise ValueError( | 
					
						
							|  |  |  |                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" | 
					
						
							|  |  |  |                 f" {self.num_heads})." | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         self.scale = self.head_size**-0.5 | 
					
						
							|  |  |  |         self.dropout = config.attention_dropout | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.num_heads = self.num_heads // weights.process_group.size() | 
					
						
							|  |  |  |         self.embed_dim = self.embed_dim // weights.process_group.size() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.qkv = TensorParallelColumnLinear.load_multi( | 
					
						
							|  |  |  |             config, | 
					
						
							|  |  |  |             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], | 
					
						
							|  |  |  |             dim=0, | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |             bias=True, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.out_proj = TensorParallelRowLinear.load( | 
					
						
							|  |  |  |             config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.is_causal = False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         hidden_states: torch.Tensor, | 
					
						
							|  |  |  |         attention_mask: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     ) -> torch.Tensor: | 
					
						
							|  |  |  |         batch_size, q_len, _ = hidden_states.size() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         qkv = self.qkv(hidden_states) | 
					
						
							|  |  |  |         query_states, key_states, value_states = qkv.split( | 
					
						
							|  |  |  |             [ | 
					
						
							|  |  |  |                 self.head_size * self.num_heads, | 
					
						
							|  |  |  |                 self.head_size * self.num_heads, | 
					
						
							|  |  |  |                 self.head_size * self.num_heads, | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |             dim=2, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         query_states = query_states.view( | 
					
						
							|  |  |  |             batch_size, q_len, self.num_heads, self.head_size | 
					
						
							|  |  |  |         ).transpose(1, 2) | 
					
						
							|  |  |  |         key_states = key_states.view( | 
					
						
							|  |  |  |             batch_size, q_len, self.num_heads, self.head_size | 
					
						
							|  |  |  |         ).transpose(1, 2) | 
					
						
							|  |  |  |         value_states = value_states.view( | 
					
						
							|  |  |  |             batch_size, q_len, self.num_heads, self.head_size | 
					
						
							|  |  |  |         ).transpose(1, 2) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         k_v_seq_len = key_states.shape[-2] | 
					
						
							|  |  |  |         attn_weights = ( | 
					
						
							|  |  |  |             torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): | 
					
						
							|  |  |  |             raise ValueError( | 
					
						
							|  |  |  |                 f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" | 
					
						
							|  |  |  |                 f" {attn_weights.size()}" | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if attention_mask is not None: | 
					
						
							|  |  |  |             if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): | 
					
						
							|  |  |  |                 raise ValueError( | 
					
						
							|  |  |  |                     f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             attn_weights = attn_weights + attention_mask | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # upcast attention to fp32 | 
					
						
							|  |  |  |         attn_weights = nn.functional.softmax( | 
					
						
							|  |  |  |             attn_weights, dim=-1, dtype=torch.float32 | 
					
						
							|  |  |  |         ).to(query_states.dtype) | 
					
						
							|  |  |  |         attn_weights = nn.functional.dropout( | 
					
						
							|  |  |  |             attn_weights, p=self.dropout, training=self.training | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         attn_output = torch.matmul(attn_weights, value_states) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size): | 
					
						
							|  |  |  |             raise ValueError( | 
					
						
							|  |  |  |                 f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is" | 
					
						
							|  |  |  |                 f" {attn_output.size()}" | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         attn_output = attn_output.transpose(1, 2).contiguous() | 
					
						
							|  |  |  |         attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         attn_output = self.out_proj(attn_output) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return attn_output | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2VisionMLP(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.config = config | 
					
						
							|  |  |  |         self.activation_fn = ACT2FN[config.hidden_act] | 
					
						
							|  |  |  |         self.fc1 = TensorParallelColumnLinear.load( | 
					
						
							|  |  |  |             prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.fc2 = TensorParallelRowLinear.load( | 
					
						
							|  |  |  |             prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | 
					
						
							|  |  |  |         hidden_states = self.fc1(hidden_states) | 
					
						
							|  |  |  |         hidden_states = self.activation_fn(hidden_states) | 
					
						
							|  |  |  |         hidden_states = self.fc2(hidden_states) | 
					
						
							|  |  |  |         return hidden_states | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2EncoderLayer(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.embed_dim = config.hidden_size | 
					
						
							|  |  |  |         self.self_attn = Idefics2VisionAttention( | 
					
						
							|  |  |  |             prefix=f"{prefix}.self_attn", config=config, weights=weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.layer_norm1 = nn.LayerNorm.load( | 
					
						
							|  |  |  |             prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.layer_norm2 = nn.LayerNorm.load( | 
					
						
							|  |  |  |             prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.mlp = Idefics2VisionMLP( | 
					
						
							|  |  |  |             prefix=f"{prefix}.mlp", config=config, weights=weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward | 
					
						
							|  |  |  |     def forward( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         hidden_states: torch.Tensor, | 
					
						
							|  |  |  |         attention_mask: torch.Tensor, | 
					
						
							|  |  |  |     ) -> torch.Tensor: | 
					
						
							|  |  |  |         residual = hidden_states | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         hidden_states = self.layer_norm1(hidden_states) | 
					
						
							|  |  |  |         hidden_states = self.self_attn( | 
					
						
							|  |  |  |             hidden_states=hidden_states, | 
					
						
							|  |  |  |             attention_mask=attention_mask, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         hidden_states = residual + hidden_states | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         residual = hidden_states | 
					
						
							|  |  |  |         hidden_states = self.layer_norm2(hidden_states) | 
					
						
							|  |  |  |         hidden_states = self.mlp(hidden_states) | 
					
						
							|  |  |  |         hidden_states = residual + hidden_states | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return hidden_states | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2Encoder(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.config = config | 
					
						
							|  |  |  |         self.layers = nn.ModuleList( | 
					
						
							|  |  |  |             [ | 
					
						
							|  |  |  |                 Idefics2EncoderLayer( | 
					
						
							|  |  |  |                     prefix=f"{prefix}.layers.{i}", config=config, weights=weights | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 for i in range(config.num_hidden_layers) | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Ignore copy | 
					
						
							|  |  |  |     def forward( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         inputs_embeds, | 
					
						
							|  |  |  |         attention_mask: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         hidden_states = inputs_embeds | 
					
						
							|  |  |  |         for encoder_layer in self.layers: | 
					
						
							|  |  |  |             hidden_states = encoder_layer( | 
					
						
							|  |  |  |                 hidden_states, | 
					
						
							|  |  |  |                 attention_mask, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         return hidden_states | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2VisionTransformer(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.config = config | 
					
						
							|  |  |  |         self.embeddings = Idefics2VisionEmbeddings( | 
					
						
							|  |  |  |             prefix=f"{prefix}.embeddings", config=config, weights=weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.encoder = Idefics2Encoder( | 
					
						
							|  |  |  |             prefix=f"{prefix}.encoder", config=config, weights=weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.post_layernorm = nn.LayerNorm.load( | 
					
						
							|  |  |  |             prefix=f"{prefix}.post_layernorm", | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |             eps=config.layer_norm_eps, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         pixel_values, | 
					
						
							|  |  |  |         patch_attention_mask: Optional[torch.BoolTensor] = None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         batch_size = pixel_values.size(0) | 
					
						
							|  |  |  |         if patch_attention_mask is None: | 
					
						
							|  |  |  |             patch_size = self.config.patch_size | 
					
						
							|  |  |  |             patch_attention_mask = torch.ones( | 
					
						
							|  |  |  |                 ( | 
					
						
							|  |  |  |                     batch_size, | 
					
						
							|  |  |  |                     pixel_values.size(2) // patch_size, | 
					
						
							|  |  |  |                     pixel_values.size(3) // patch_size, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             patch_attention_mask = patch_attention_mask.to( | 
					
						
							|  |  |  |                 dtype=torch.bool, device=pixel_values.device | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         hidden_states = self.embeddings( | 
					
						
							|  |  |  |             pixel_values=pixel_values, patch_attention_mask=patch_attention_mask | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         patch_attention_mask = patch_attention_mask.view(batch_size, -1) | 
					
						
							|  |  |  |         # The call to `_upad_input` in `_flash_attention_forward` is expensive | 
					
						
							|  |  |  |         # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), | 
					
						
							|  |  |  |         # avoiding passing the attention_mask, which is equivalent to attending to the full sequence | 
					
						
							|  |  |  |         if not torch.any(~patch_attention_mask): | 
					
						
							|  |  |  |             patch_attention_mask = None | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             patch_attention_mask = _prepare_4d_attention_mask( | 
					
						
							|  |  |  |                 patch_attention_mask, hidden_states.dtype | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         encoder_outputs = self.encoder( | 
					
						
							|  |  |  |             inputs_embeds=hidden_states, | 
					
						
							|  |  |  |             attention_mask=patch_attention_mask, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         last_hidden_state = encoder_outputs | 
					
						
							|  |  |  |         last_hidden_state = self.post_layernorm(last_hidden_state) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return last_hidden_state | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2MLP(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         act = config.text_config.hidden_act | 
					
						
							|  |  |  |         self.act = ( | 
					
						
							|  |  |  |             ACT2FN[act] | 
					
						
							|  |  |  |             if "gelu" not in act | 
					
						
							|  |  |  |             else lambda x: torch.nn.functional.gelu( | 
					
						
							|  |  |  |                 x, | 
					
						
							|  |  |  |                 approximate=( | 
					
						
							|  |  |  |                     "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" | 
					
						
							|  |  |  |                 ), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.gate_up_proj = TensorParallelColumnLinear.load_multi( | 
					
						
							|  |  |  |             config, | 
					
						
							|  |  |  |             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |             dim=0, | 
					
						
							|  |  |  |             bias=False, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.down_proj = TensorParallelRowLinear.load( | 
					
						
							|  |  |  |             config, | 
					
						
							|  |  |  |             prefix=f"{prefix}.down_proj", | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |             bias=False, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, hidden_states): | 
					
						
							|  |  |  |         start_shape = hidden_states.shape[:-1] | 
					
						
							|  |  |  |         gate_up_states = self.gate_up_proj(hidden_states) | 
					
						
							|  |  |  |         intermediate_size = gate_up_states.shape[-1] // 2 | 
					
						
							|  |  |  |         gate_up_states = gate_up_states.view(-1, 2, intermediate_size) | 
					
						
							|  |  |  |         return self.down_proj( | 
					
						
							|  |  |  |             self.act(gate_up_states[:, 0]) * gate_up_states[:, 1] | 
					
						
							|  |  |  |         ).view(*start_shape, -1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2RMSNorm(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, weights, eps): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Idefics2RMSNorm is equivalent to T5LayerNorm | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.weight = nn.Parameter( | 
					
						
							|  |  |  |             weights.get_tensor(f"{prefix}.weight"), requires_grad=False | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.variance_epsilon = eps | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, hidden_states): | 
					
						
							|  |  |  |         input_dtype = hidden_states.dtype | 
					
						
							|  |  |  |         hidden_states = hidden_states.to(torch.float32) | 
					
						
							|  |  |  |         variance = hidden_states.pow(2).mean(-1, keepdim=True) | 
					
						
							|  |  |  |         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) | 
					
						
							|  |  |  |         return self.weight * hidden_states.to(input_dtype) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2PerceiverAttention(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.layer_idx = None | 
					
						
							|  |  |  |         self.hidden_size = config.text_config.hidden_size | 
					
						
							|  |  |  |         self.num_heads = config.perceiver_config.resampler_n_heads | 
					
						
							|  |  |  |         self.head_size = config.perceiver_config.resampler_head_dim | 
					
						
							|  |  |  |         self.num_key_value_heads = config.perceiver_config.num_key_value_heads | 
					
						
							|  |  |  |         self.num_key_value_groups = self.num_heads // self.num_key_value_heads | 
					
						
							|  |  |  |         self.attention_dropout = config.perceiver_config.attention_dropout | 
					
						
							|  |  |  |         self.num_heads = self.num_heads // weights.process_group.size() | 
					
						
							|  |  |  |         self.num_key_value_heads = ( | 
					
						
							|  |  |  |             self.num_key_value_heads // weights.process_group.size() | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.q_proj = TensorParallelColumnLinear.load( | 
					
						
							|  |  |  |             config, | 
					
						
							|  |  |  |             prefix=f"{prefix}.q_proj", | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |             bias=False, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.kv = TensorParallelColumnLinear.load_multi( | 
					
						
							|  |  |  |             config, | 
					
						
							|  |  |  |             prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"], | 
					
						
							|  |  |  |             dim=0, | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |             bias=False, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.o_proj = TensorParallelRowLinear.load( | 
					
						
							|  |  |  |             config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.is_causal = False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         latents: torch.Tensor, | 
					
						
							|  |  |  |         context: torch.Tensor, | 
					
						
							|  |  |  |         attention_mask: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | 
					
						
							|  |  |  |         bsz, q_len, _ = latents.size() | 
					
						
							|  |  |  |         kv_seq_len = q_len + context.size()[1] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         hidden_states = torch.concat([context, latents], dim=-2) | 
					
						
							|  |  |  |         query_states = self.q_proj(latents) | 
					
						
							|  |  |  |         kv = self.kv(hidden_states) | 
					
						
							|  |  |  |         key_states, value_states = kv.split( | 
					
						
							|  |  |  |             [ | 
					
						
							|  |  |  |                 self.head_size * self.num_key_value_heads, | 
					
						
							|  |  |  |                 self.head_size * self.num_key_value_heads, | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |             dim=2, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         query_states = query_states.view( | 
					
						
							|  |  |  |             bsz, q_len, self.num_heads, self.head_size | 
					
						
							|  |  |  |         ).transpose(1, 2) | 
					
						
							|  |  |  |         key_states = key_states.view( | 
					
						
							|  |  |  |             bsz, kv_seq_len, self.num_key_value_heads, self.head_size | 
					
						
							|  |  |  |         ).transpose(1, 2) | 
					
						
							|  |  |  |         value_states = value_states.view( | 
					
						
							|  |  |  |             bsz, kv_seq_len, self.num_key_value_heads, self.head_size | 
					
						
							|  |  |  |         ).transpose(1, 2) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # repeat k/v heads if n_kv_heads < n_heads | 
					
						
							|  |  |  |         key_states = repeat_kv(key_states, self.num_key_value_groups) | 
					
						
							|  |  |  |         value_states = repeat_kv(value_states, self.num_key_value_groups) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         attn_weights = torch.matmul( | 
					
						
							|  |  |  |             query_states, key_states.transpose(2, 3) | 
					
						
							|  |  |  |         ) / math.sqrt(self.head_size) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): | 
					
						
							|  |  |  |             raise ValueError( | 
					
						
							|  |  |  |                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" | 
					
						
							|  |  |  |                 f" {attn_weights.size()}" | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if attention_mask is not None: | 
					
						
							|  |  |  |             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): | 
					
						
							|  |  |  |                 raise ValueError( | 
					
						
							|  |  |  |                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             attn_weights = attn_weights + attention_mask | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # upcast attention to fp32 | 
					
						
							|  |  |  |         attn_weights = nn.functional.softmax( | 
					
						
							|  |  |  |             attn_weights, dim=-1, dtype=torch.float32 | 
					
						
							|  |  |  |         ).to(query_states.dtype) | 
					
						
							|  |  |  |         attn_output = torch.matmul(attn_weights, value_states) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size): | 
					
						
							|  |  |  |             raise ValueError( | 
					
						
							|  |  |  |                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is" | 
					
						
							|  |  |  |                 f" {attn_output.size()}" | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         attn_output = attn_output.transpose(1, 2).contiguous() | 
					
						
							|  |  |  |         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         attn_output = self.o_proj(attn_output) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return attn_output | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2PerceiverLayer(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.hidden_size = config.text_config.hidden_size | 
					
						
							|  |  |  |         self.n_latents = config.perceiver_config.resampler_n_latents | 
					
						
							|  |  |  |         self.depth = config.perceiver_config.resampler_depth | 
					
						
							|  |  |  |         self.rms_norm_eps = config.text_config.rms_norm_eps | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.input_latents_norm = Idefics2RMSNorm( | 
					
						
							|  |  |  |             prefix=f"{prefix}.input_latents_norm", | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |             eps=self.rms_norm_eps, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.input_context_norm = Idefics2RMSNorm( | 
					
						
							|  |  |  |             prefix=f"{prefix}.input_context_norm", | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |             eps=self.rms_norm_eps, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.self_attn = Idefics2PerceiverAttention( | 
					
						
							|  |  |  |             prefix=f"{prefix}.self_attn", config=config, weights=weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.post_attention_layernorm = Idefics2RMSNorm( | 
					
						
							|  |  |  |             prefix=f"{prefix}.post_attention_layernorm", | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |             eps=self.rms_norm_eps, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         latents: torch.Tensor, | 
					
						
							|  |  |  |         context: torch.Tensor, | 
					
						
							|  |  |  |         attention_mask: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` | 
					
						
							|  |  |  |             context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` | 
					
						
							|  |  |  |             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size | 
					
						
							|  |  |  |                 `(batch, sequence_length)` where padding elements are indicated by 0. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         residual = latents | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         latents = self.input_latents_norm(latents) | 
					
						
							|  |  |  |         context = self.input_context_norm(context) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         latents = self.self_attn( | 
					
						
							|  |  |  |             latents=latents, | 
					
						
							|  |  |  |             context=context, | 
					
						
							|  |  |  |             attention_mask=attention_mask, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         latents = residual + latents | 
					
						
							|  |  |  |         residual = latents | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         latents = self.post_attention_layernorm(latents) | 
					
						
							|  |  |  |         latents = self.mlp(latents) | 
					
						
							|  |  |  |         latents = residual + latents | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return latents | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2PerceiverResampler(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights) -> None: | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.hidden_size = config.text_config.hidden_size | 
					
						
							|  |  |  |         self.hidden_act = config.perceiver_config.hidden_act | 
					
						
							|  |  |  |         self.n_latents = config.perceiver_config.resampler_n_latents | 
					
						
							|  |  |  |         self.depth = config.perceiver_config.resampler_depth | 
					
						
							|  |  |  |         self.rms_norm_eps = config.text_config.rms_norm_eps | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Create Latents for Perceiver | 
					
						
							|  |  |  |         self.latents = weights.get_tensor(f"{prefix}.latents") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Create Transformer Blocks | 
					
						
							|  |  |  |         self.layers = nn.ModuleList( | 
					
						
							|  |  |  |             [ | 
					
						
							|  |  |  |                 Idefics2PerceiverLayer( | 
					
						
							|  |  |  |                     prefix=f"{prefix}.layers.{idx}", config=config, weights=weights | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 for idx in range(self.depth) | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.norm = Idefics2RMSNorm( | 
					
						
							|  |  |  |             prefix=f"{prefix}.norm", | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |             eps=config.text_config.rms_norm_eps, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         context: torch.Tensor, | 
					
						
							|  |  |  |         attention_mask, | 
					
						
							|  |  |  |     ) -> torch.Tensor: | 
					
						
							|  |  |  |         # seq embed -> bsz seq embed | 
					
						
							|  |  |  |         latents = self.latents.unsqueeze(0).expand( | 
					
						
							|  |  |  |             (context.shape[0], *self.latents.size()) | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         latent_attention_mask = torch.ones( | 
					
						
							|  |  |  |             (attention_mask.size(0), latents.size(1)), | 
					
						
							|  |  |  |             dtype=attention_mask.dtype, | 
					
						
							|  |  |  |             device=attention_mask.device, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1) | 
					
						
							|  |  |  |         attention_mask = _prepare_4d_attention_mask( | 
					
						
							|  |  |  |             attention_mask, latents.dtype, tgt_len=self.n_latents | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         compressed_context = latents | 
					
						
							|  |  |  |         for perceiver_layer in self.layers: | 
					
						
							|  |  |  |             compressed_context = perceiver_layer( | 
					
						
							|  |  |  |                 compressed_context, | 
					
						
							|  |  |  |                 context, | 
					
						
							|  |  |  |                 attention_mask=attention_mask, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         compressed_context = self.norm(compressed_context) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return compressed_context | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2Connector(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.modality_projection = Idefics2MLP( | 
					
						
							|  |  |  |             prefix=f"{prefix}.modality_projection", config=config, weights=weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.perceiver_resampler = Idefics2PerceiverResampler( | 
					
						
							|  |  |  |             prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, image_hidden_states, attention_mask): | 
					
						
							|  |  |  |         image_hidden_states = self.modality_projection(image_hidden_states) | 
					
						
							|  |  |  |         image_hidden_states = self.perceiver_resampler( | 
					
						
							|  |  |  |             context=image_hidden_states, attention_mask=attention_mask | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         return image_hidden_states | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Idefics2ForConditionalGeneration(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, prefix, config, weights): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							| 
									
										
										
										
											2024-07-16 05:58:25 +00:00
										 |  |  |         config.vision_config.quantize = None | 
					
						
							| 
									
										
										
										
											2024-05-14 10:33:18 +00:00
										 |  |  |         config.vision_config.speculator = config.speculator | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  |         config.text_config.quantize = config.quantize | 
					
						
							| 
									
										
										
										
											2024-05-14 10:33:18 +00:00
										 |  |  |         config.text_config.speculator = config.speculator | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         vision_config = config.vision_config | 
					
						
							|  |  |  |         self.text_model = load_text_model( | 
					
						
							|  |  |  |             prefix="model" if not prefix else f"{prefix}.model", | 
					
						
							|  |  |  |             config=config.text_config, | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |             name="text_model", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.dtype = weights.dtype | 
					
						
							| 
									
										
										
										
											2024-07-16 05:58:25 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # The vision and connector models are not quantized. | 
					
						
							| 
									
										
											  
											
												Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights
Handling of quantized weights was split between two mechanisms:
- For quantized checkpoints, we used the new weight loader
  infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
  instead relied on conditional in `get_linear`.
Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overrided by `get_linear`, which string-checks
`quantizer`. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.
This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:
- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers,
  `get_linear` does not need to know how to handle quantizer linear
  layers.
- All quantizer weights are strongly typed, we don't pass around
  raw tensors.
- We don't have to pass around the `quantizer` string everywhere.
* Exclude non-MLP layers when using FP8 quantization with Llama
											
										 
											2024-07-19 07:37:39 +00:00
										 |  |  |         with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): | 
					
						
							| 
									
										
										
										
											2024-07-16 05:58:25 +00:00
										 |  |  |             self.vision_model = Idefics2VisionTransformer( | 
					
						
							|  |  |  |                 prefix=( | 
					
						
							|  |  |  |                     f"{prefix}.model.vision_model" if prefix else "model.vision_model" | 
					
						
							|  |  |  |                 ), | 
					
						
							|  |  |  |                 config=vision_config, | 
					
						
							|  |  |  |                 weights=weights, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
											  
											
												Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights
Handling of quantized weights was split between two mechanisms:
- For quantized checkpoints, we used the new weight loader
  infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
  instead relied on conditional in `get_linear`.
Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overrided by `get_linear`, which string-checks
`quantizer`. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.
This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:
- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers,
  `get_linear` does not need to know how to handle quantizer linear
  layers.
- All quantizer weights are strongly typed, we don't pass around
  raw tensors.
- We don't have to pass around the `quantizer` string everywhere.
* Exclude non-MLP layers when using FP8 quantization with Llama
											
										 
											2024-07-19 07:37:39 +00:00
										 |  |  |             config.quantize = None | 
					
						
							|  |  |  |             self.connector = Idefics2Connector( | 
					
						
							|  |  |  |                 prefix=f"{prefix}.model.connector" if prefix else "model.connector", | 
					
						
							|  |  |  |                 config=config, | 
					
						
							|  |  |  |                 weights=weights, | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-07-16 05:58:25 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  |         self.config = config | 
					
						
							|  |  |  |         self.image_seq_len = config.perceiver_config.resampler_n_latents | 
					
						
							|  |  |  |         self.image_token_id = config.image_token_id | 
					
						
							|  |  |  |         self.pad_token_id = ( | 
					
						
							|  |  |  |             config.pad_token_id if config.pad_token_id is not None else -1 | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _merge_input_ids_with_image_features( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         input_ids: torch.Tensor, | 
					
						
							|  |  |  |         inputs_embeds: torch.Tensor, | 
					
						
							|  |  |  |         image_features: torch.Tensor, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         """In place merges in vision_embeddings with inputs_embeds.""" | 
					
						
							|  |  |  |         # mask = input_ids == self.config.image_token_index | 
					
						
							|  |  |  |         mask = input_ids == self.config.image_token_id | 
					
						
							|  |  |  |         # Let's pray we have enabled enough slots ! | 
					
						
							|  |  |  |         inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) | 
					
						
							|  |  |  |         return inputs_embeds | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-06 16:01:59 +00:00
										 |  |  |     def get_vision_embeds( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         pixel_values: torch.FloatTensor, | 
					
						
							|  |  |  |         pixel_attention_mask: Optional[torch.FloatTensor] = None, | 
					
						
							|  |  |  |         image_sizes: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |         image_grid_thw: Optional[torch.LongTensor] = None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         assert pixel_values is not None | 
					
						
							|  |  |  |         batch_size, num_images, num_channels, height, width = pixel_values.shape | 
					
						
							|  |  |  |         all_states = [] | 
					
						
							|  |  |  |         all_pixel_values = pixel_values | 
					
						
							|  |  |  |         all_pixel_mask = pixel_attention_mask | 
					
						
							|  |  |  |         for i in range(batch_size): | 
					
						
							|  |  |  |             pixel_values = all_pixel_values.to(dtype=self.dtype)  # fp16 compatibility | 
					
						
							|  |  |  |             pixel_values = pixel_values[i : i + 1] | 
					
						
							|  |  |  |             pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Remove padding images - padding images are full 0. | 
					
						
							|  |  |  |             nb_values_per_image = pixel_values.shape[1:].numel() | 
					
						
							|  |  |  |             real_images_inds = (pixel_values == 0.0).sum( | 
					
						
							|  |  |  |                 dim=(-1, -2, -3) | 
					
						
							|  |  |  |             ) != nb_values_per_image | 
					
						
							|  |  |  |             pixel_values = pixel_values[real_images_inds].contiguous() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Handle the vision attention mask | 
					
						
							|  |  |  |             if pixel_attention_mask is None: | 
					
						
							|  |  |  |                 pixel_attention_mask = torch.ones( | 
					
						
							|  |  |  |                     size=( | 
					
						
							|  |  |  |                         pixel_values.size(0), | 
					
						
							|  |  |  |                         pixel_values.size(2), | 
					
						
							|  |  |  |                         pixel_values.size(3), | 
					
						
							|  |  |  |                     ), | 
					
						
							|  |  |  |                     dtype=torch.bool, | 
					
						
							|  |  |  |                     device=pixel_values.device, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 # Remove padding images from the mask/pP p | 
					
						
							|  |  |  |                 pixel_attention_mask = all_pixel_mask[i : i + 1] | 
					
						
							|  |  |  |                 pixel_attention_mask = pixel_attention_mask.view( | 
					
						
							|  |  |  |                     1 * num_images, *pixel_attention_mask.shape[2:] | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 pixel_attention_mask = pixel_attention_mask[ | 
					
						
							|  |  |  |                     real_images_inds | 
					
						
							|  |  |  |                 ].contiguous() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             patch_size = self.config.vision_config.patch_size | 
					
						
							|  |  |  |             patches_subgrid = pixel_attention_mask.unfold( | 
					
						
							|  |  |  |                 dimension=1, size=patch_size, step=patch_size | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             patches_subgrid = patches_subgrid.unfold( | 
					
						
							|  |  |  |                 dimension=2, size=patch_size, step=patch_size | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Get sequence from the vision encoder | 
					
						
							|  |  |  |             image_hidden_states = self.vision_model( | 
					
						
							|  |  |  |                 pixel_values=pixel_values, | 
					
						
							|  |  |  |                 patch_attention_mask=patch_attention_mask, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Modality projection & resampling | 
					
						
							|  |  |  |             image_hidden_states = self.connector( | 
					
						
							|  |  |  |                 image_hidden_states, | 
					
						
							|  |  |  |                 attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             all_states.append(image_hidden_states) | 
					
						
							|  |  |  |         image_hidden_states = torch.stack(all_states, dim=0) | 
					
						
							|  |  |  |         return image_hidden_states.view(-1, image_hidden_states.shape[-1]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_inputs_embeds( | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  |         self, | 
					
						
							|  |  |  |         input_ids: torch.Tensor, | 
					
						
							| 
									
										
										
										
											2025-05-06 16:01:59 +00:00
										 |  |  |         vision_embeds: torch.Tensor = None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         inputs_embeds = self.text_model.embed_tokens(input_ids) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if vision_embeds is not None: | 
					
						
							|  |  |  |             # When we generate, we don't want to replace the potential image_token_id that we generated by images | 
					
						
							|  |  |  |             # that simply don't exist | 
					
						
							|  |  |  |             inputs_embeds = self._merge_input_ids_with_image_features( | 
					
						
							|  |  |  |                 input_ids, inputs_embeds, vision_embeds | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         return inputs_embeds | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         inputs_embeds: torch.Tensor, | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  |         position_ids: torch.Tensor, | 
					
						
							|  |  |  |         cu_seqlen_prefill: Optional[torch.Tensor], | 
					
						
							|  |  |  |         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], | 
					
						
							|  |  |  |         block_tables: torch.Tensor, | 
					
						
							|  |  |  |         slots: torch.Tensor, | 
					
						
							| 
									
										
										
										
											2024-08-29 14:29:01 +00:00
										 |  |  |         seqlen: Seqlen, | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  |         max_s: int, | 
					
						
							|  |  |  |         prefill_cache_indices: Optional[torch.Tensor], | 
					
						
							|  |  |  |         lm_head_indices: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |         # Unused here | 
					
						
							| 
									
										
										
										
											2025-05-06 16:01:59 +00:00
										 |  |  |         attention_mask: Optional[torch.BoolTensor] = None, | 
					
						
							| 
									
										
										
										
											2024-06-25 18:46:27 +00:00
										 |  |  |         adapter_data: Optional[torch.Tensor] = None, | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  |     ): | 
					
						
							|  |  |  |         hidden_states = self.text_model.model( | 
					
						
							|  |  |  |             inputs_embeds=inputs_embeds, | 
					
						
							|  |  |  |             position_ids=position_ids, | 
					
						
							|  |  |  |             cu_seqlen_prefill=cu_seqlen_prefill, | 
					
						
							|  |  |  |             kv_cache=kv_cache, | 
					
						
							|  |  |  |             block_tables=block_tables, | 
					
						
							|  |  |  |             slots=slots, | 
					
						
							| 
									
										
										
										
											2024-08-29 14:29:01 +00:00
										 |  |  |             seqlen=seqlen, | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  |             max_s=max_s, | 
					
						
							|  |  |  |             true_max_s=max_s, | 
					
						
							|  |  |  |             prefill_cache_indices=None, | 
					
						
							| 
									
										
										
										
											2024-07-24 08:39:08 +00:00
										 |  |  |             adapter_data=adapter_data, | 
					
						
							| 
									
										
										
										
											2024-04-23 21:04:44 +00:00
										 |  |  |         ) | 
					
						
							|  |  |  |         if lm_head_indices is not None: | 
					
						
							|  |  |  |             hidden_states = hidden_states[lm_head_indices] | 
					
						
							|  |  |  |         logits, speculative_logits = self.text_model.lm_head(hidden_states) | 
					
						
							|  |  |  |         return logits, speculative_logits |