# coding=utf-8 # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Mllama model.""" from typing import Optional, Tuple, List import torch import torch.utils.checkpoint from torch import nn import math from transformers.activations import ACT2FN from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_outputs import ( CausalLMOutputWithPast, BaseModelOutputWithPast, ) from transformers.cache_utils import ( StaticCache, DynamicCache, ) from transformers.modeling_attn_mask_utils import AttentionMaskConverter import torch.nn.functional as F from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, SpeculativeHead, FastLinear, ) # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: torch.Tensor,
    num_patches: int,
    target_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Expand aspect ratio mask to target_length
    batch_size, max_num_tiles = aspect_ratio_mask.shape
    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
    attention_mask = attention_mask.repeat(1, 1, target_length, 1)

    # Mask padding patches
    pad_patches = target_length - num_patches
    attention_mask[:, :, -pad_patches:] = 0

    # Invert the mask (0 -> 1, 1 -> 0)
    attention_mask = 1 - attention_mask

    # Reshape to 2D and create 4D attention mask
    # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
    attention_mask = attention_mask.reshape(
        batch_size, max_num_tiles * target_length, 1
    )
    attention_mask = (
        attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
    )
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask


# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`torch.Tensor`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
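        # Editor's note (illustrative example, not from the original source): for the 2D-mask
        # path below, with sequence_length=2, target_length=4 and cache_position=[2, 3], the
        # per-batch mask built by this function is
        #     [[0, 0, 0, min_dtype],
        #      [0, 0, 0, 0        ]]
        # i.e. the query at cache position 2 may attend to keys 0..2, and the query at
        # position 3 to keys 0..3.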
        causal_mask = attention_mask
    else:
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(
            target_length, device=device
        ) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = (
                causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            )
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)

    return causal_mask


def _prepare_cross_attention_mask(
    cross_attention_mask: torch.Tensor,
    num_vision_tokens: int,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # reshape so it can be used by attn module
    batch_size, text_total_length, *_ = cross_attention_mask.shape
    cross_attention_mask = cross_attention_mask.repeat_interleave(
        num_vision_tokens, dim=3
    )
    cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
    cross_attention_mask = cross_attention_mask.unsqueeze(1)

    # invert the mask
    inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
    cross_attention_mask = inverted_cross_attn_mask.masked_fill(
        inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
    )

    # apply full-row bias, which returns a 4D tensor of shape [B, H, S1, 1] whose value is 0
    # if a full row in the cross-attention mask's last dimension contains negative infinity
    # values, and 1 otherwise
    negative_inf_value = torch.finfo(dtype).min
    full_text_row_masked_out_mask = (
        (cross_attention_mask != negative_inf_value)
        .any(dim=-1)
        .type_as(cross_attention_mask)[..., None]
    )
    cross_attention_mask *= full_text_row_masked_out_mask

    return cross_attention_mask, full_text_row_masked_out_mask


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class MllamaVisionSdpaAttention(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()

        self.embed_dim = config.hidden_size
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_heads = config.attention_heads // weights.process_group.size()

        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_state)
        query, key, value = qkv.split(
            [
                self.head_dim * self.num_heads,
                self.head_dim * self.num_heads,
                self.head_dim * self.num_heads,
            ],
            dim=2,
        )

        batch_size, q_seq_len, _ = query.shape
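        # Editor's note (shape sketch under assumed config values, e.g. hidden_size=1280 and
        # 16 attention heads on a single rank, head_dim=80): query/key/value above are each
        # (batch, seq, num_heads * head_dim); the views and transposes below produce the
        # (batch, num_heads, seq, head_dim) layout expected by F.scaled_dot_product_attention.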
_, kv_seq_len, _ = key.shape query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) attn_output = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, q_seq_len, -1) output = self.o_proj(attn_output) return output class MllamaVisionEncoderLayer(nn.Module): def __init__(self, *, prefix, config, weights, is_gated: bool): super().__init__() self.hidden_size = config.hidden_size self.num_attention_heads = config.attention_heads self.is_gated = is_gated self.intermediate_size = config.intermediate_size self.self_attn = MllamaVisionSdpaAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) self.mlp = MllamaVisionMLP( prefix=f"{prefix}.mlp", config=config, weights=weights ) self.input_layernorm = nn.LayerNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05 ) self.post_attention_layernorm = nn.LayerNorm.load( prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05 ) # there used to be an if else here, no code path if is_gated: self.gate_attn = nn.Parameter( weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False ) self.gate_ffn = nn.Parameter( weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False ) def forward( self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ): # Self Attention residual = hidden_state hidden_state = self.input_layernorm(hidden_state) hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() hidden_state = residual + gate_attn * hidden_state # Feed forward residual = hidden_state hidden_state = self.post_attention_layernorm(hidden_state) hidden_state = self.mlp(hidden_state) gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() hidden_state = residual + gate_ffn * hidden_state return hidden_state class MllamaVisionEncoder(nn.Module): def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int): super().__init__() self.config = config self.layers = [ MllamaVisionEncoderLayer( prefix=f"{prefix}.layers.{i}", config=config, weights=weights, is_gated=is_gated, ) for i in range(num_layers) ] def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ): encoder_states = [hidden_states] for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, attention_mask, ) hidden_states = layer_outputs encoder_states.append(hidden_states) return hidden_states, encoder_states class MllamaPrecomputedAspectRatioEmbedding(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() self.max_num_tiles = config.max_num_tiles self.hidden_size = config.hidden_size self.max_aspect_ratio_id = config.max_aspect_ratio_id self.embedding = TensorParallelEmbedding( prefix=f"{prefix}.embedding", weights=weights ) self.gate = nn.Parameter( weights.get_tensor(f"{prefix}.gate"), requires_grad=False ) def forward( self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor ) -> torch.Tensor: embeddings = self.embedding(aspect_ratio_ids) embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) # Always gated. 
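        # Editor's note: `gate` is a learned scalar, so tanh(gate) lies in (-1, 1); a gate
        # near zero effectively disables the aspect-ratio (tile) embedding, mirroring the
        # tanh-gating used by the gated vision encoder layers above.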
embeddings = embeddings * self.gate.tanh() hidden_state = hidden_state + embeddings return hidden_state class MllamaPrecomputedPositionEmbedding(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() self.max_num_tiles = config.max_num_tiles self.max_aspect_ratio_id = config.max_aspect_ratio_id self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 self.hidden_size = config.hidden_size self.scale = config.hidden_size**-0.5 self.gate = nn.Parameter( weights.get_tensor(f"{prefix}.gate"), requires_grad=False ) # position embedding embedding = nn.Parameter( weights.get_tensor(f"{prefix}.embedding"), requires_grad=False ) self.gated_position_embedding = (1 - self.gate.tanh()) * embedding self.tile_embedding = TensorParallelEmbedding( prefix=f"{prefix}.tile_embedding", weights=weights ) def forward( self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor ) -> torch.Tensor: # position embeddings hidden_state = hidden_state + self.gated_position_embedding.view( 1, 1, self.num_patches, self.hidden_size ) # precomputed tile position embeddings tile_position_embedding = self.tile_embedding(aspect_ratio_ids) batch_size = hidden_state.shape[0] tile_position_embedding = tile_position_embedding.reshape( batch_size, self.max_num_tiles, self.num_patches, self.hidden_size ) gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding hidden_state = hidden_state + gated_tile_position_embedding return hidden_state class MllamaVisionModel(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() self.image_size = config.image_size self.patch_size = config.patch_size self.max_num_tiles = config.max_num_tiles self.hidden_size = config.hidden_size self.num_channels = config.num_channels self.intermediate_layers_indices = config.intermediate_layers_indices self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 self.scale = config.hidden_size**-0.5 self.dtype = weights.dtype self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", bias=False, ) self.patch_embedding.weight = nn.Parameter( weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False ) self.class_embedding = nn.Parameter( weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False ) self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( prefix=f"{prefix}.gated_positional_embedding", config=config, weights=weights, ) self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( prefix=f"{prefix}.pre_tile_positional_embedding", config=config, weights=weights, ) self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( prefix=f"{prefix}.post_tile_positional_embedding", config=config, weights=weights, ) ## layer norms self.layernorm_pre = nn.LayerNorm.load( prefix=f"{prefix}.layernorm_pre", weights=weights, # torch default eps=1e-05, ) self.layernorm_post = nn.LayerNorm.load( prefix=f"{prefix}.layernorm_post", weights=weights, # torch default eps=1e-05, ) ## encoders self.transformer = MllamaVisionEncoder( prefix=f"{prefix}.transformer", config=config, weights=weights, is_gated=False, num_layers=config.num_hidden_layers, ) self.global_transformer = MllamaVisionEncoder( prefix=f"{prefix}.global_transformer", config=config, weights=weights, is_gated=True, num_layers=config.num_global_layers, ) def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: batch_size, _, 
hidden_size = hidden_state.shape class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) hidden_state = torch.cat([class_embedding, hidden_state], dim=1) return hidden_state def forward( self, pixel_values: torch.Tensor, aspect_ratio_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> torch.Tensor: batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( pixel_values.shape ) pixel_values = pixel_values.reshape( batch_size * num_concurrent_media * num_tiles, num_channels, height, width ) aspect_ratio_ids = aspect_ratio_ids.reshape( batch_size * num_concurrent_media, -1 ) # patch embedding patch_embeds = self.patch_embedding(pixel_values) hidden_state = patch_embeds.flatten(2).transpose(1, 2) # tile embeddings _, num_patches, dim = hidden_state.shape hidden_state = hidden_state.reshape( batch_size * num_concurrent_media, num_tiles, -1, dim ) hidden_state = self.pre_tile_positional_embedding( hidden_state, aspect_ratio_ids ) # apply cls token hidden_state = hidden_state.reshape( batch_size * num_concurrent_media * num_tiles, num_patches, dim ) hidden_state = self.apply_class_embedding(hidden_state) num_patches += 1 # apply position embeddings hidden_state = hidden_state.reshape( batch_size * num_concurrent_media, num_tiles, num_patches, dim ) hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) # apply encoder hidden_state = self.layernorm_pre(hidden_state) # Compute the number of tokens to pad num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 # Compute padding tuple for pad function padding = ( 0, 0, 0, num_padding_patches, ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) # Pad the tensor hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) slice_index = -num_padding_patches if num_padding_patches > 0 else None if attention_mask is not None: attention_mask = attention_mask.reshape( batch_size * num_concurrent_media, -1 ) attention_mask = _prepare_aspect_ratio_attention_mask( aspect_ratio_mask=attention_mask, num_patches=self.num_patches, target_length=hidden_state.shape[2], dtype=self.dtype, ) hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) hidden_state, all_intermediate_hidden_states = self.transformer( hidden_state, attention_mask=attention_mask, ) intermediate_hidden_states = [ hidden_state for idx, hidden_state in enumerate(all_intermediate_hidden_states) if idx in self.intermediate_layers_indices ] intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) # apply global encoder hidden_state = self.layernorm_post(hidden_state) hidden_state = hidden_state.reshape( batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim, ) hidden_state = self.post_tile_positional_embedding( hidden_state, aspect_ratio_ids ) hidden_state = hidden_state.reshape( batch_size * num_concurrent_media, num_tiles * (num_patches + num_padding_patches), dim, ) hidden_state, _ = self.global_transformer( hidden_state, attention_mask=attention_mask ) hidden_state = hidden_state.reshape( batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim, ) hidden_state = hidden_state[:, :, :slice_index] # adding intermediate layer outputs hidden_state = hidden_state.reshape( batch_size, num_concurrent_media, num_tiles, num_patches, dim ) intermediate_hidden_states = intermediate_hidden_states.reshape( batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, -1, ) intermediate_hidden_states = 
intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )

        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
        return hidden_state


class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, *, prefix, config, weights, layer_idx):
        super().__init__()
        self.config = config
        self.num_heads = self.config.num_attention_heads
        self.num_key_value_heads = self.config.num_key_value_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.layer_idx = layer_idx

        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            self.num_key_value_heads // weights.process_group.size()
        )

        self.q_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.k_proj",
            weights=weights,
            bias=False,
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.v_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.q_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.k_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[torch.Tensor] = None,
        past_key_value=None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        query_states = self.q_norm(query_states)

        if cross_attention_states is not None:
            key_states = self.k_proj(cross_attention_states)
            value_states = self.v_proj(cross_attention_states)
            key_states = key_states.view(
                bsz, -1, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            value_states = value_states.view(
                bsz, -1, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            key_states = self.k_norm(key_states)
            if past_key_value is not None:
                # If we have a new image together with new tokens, key_states was only computed
                # for that new image; we still update the cached cross-attention key/value states
                # (past image and new image) and use the updated cache.
                key_states, value_states = past_key_value.update(
                    key_states,
                    value_states,
                    self.layer_idx,
                    {"cache_position": cache_position},
                )
        elif cache_position[0] != 0:
            key_states, value_states = (
                past_key_value[self.layer_idx][0],
                past_key_value[self.layer_idx][1],
            )
        else:
            raise ValueError(
                "Cross attention layer can't find either `cross_attn_states` or cached key/value states!"
) attn_weights = torch.matmul( query_states, key_states.transpose(2, 3) ) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query_states.dtype) attn_weights = nn.functional.dropout( attn_weights, p=self.dropout, training=self.training ) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.o_proj(attn_output) return attn_output, attn_weights, past_key_value # Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText class MllamaTextMLP(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) self.gate_up_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], weights=weights, dim=0, bias=False, ) self.down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=False, ) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): shape = x.shape gate_up_states = self.gate_up_proj(x) gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size) result = self.down_proj( self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1] ) return result class MllamaCrossAttentionDecoderLayer(torch.nn.Module): """Cross-attention transformer block with tanh-gated attention and feedforward.""" def __init__(self, *, prefix, config, weights, layer_idx) -> None: super().__init__() self.cross_attn = MllamaTextCrossAttention( prefix=f"{prefix}.cross_attn", config=config, weights=weights, layer_idx=layer_idx, ) self.input_layernorm = MllamaTextRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps ) self.cross_attn_attn_gate = torch.nn.Parameter( weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False ) self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.post_attention_layernorm = MllamaTextRMSNorm.load( prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps, ) self.cross_attn_mlp_gate = torch.nn.Parameter( weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False ) self.layer_idx = layer_idx def forward( self, hidden_states: torch.Tensor, cross_attention_states: torch.Tensor, cross_attention_mask: torch.Tensor, attention_mask: torch.Tensor, full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], past_key_value=None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> torch.Tensor: if past_key_value is not None: is_mixed = False if cross_attention_states is None: out_hidden_states = hidden_states[:] indices = [] for i, k in enumerate(past_key_value[self.layer_idx][0]): if isinstance(k, torch.Tensor): indices.append(i) from loguru import logger logger.info(f"Indices {indices}") if len(indices) == 0: return hidden_states is_mixed = True if len(indices) == hidden_states.shape[0]: is_mixed = False if is_mixed: hidden_states = hidden_states[indices] # Dirty hack _past_key_value = [None] * len(past_key_value) _past_key_value[self.layer_idx] = ( torch.stack( [ k for i, k in 
enumerate(past_key_value[self.layer_idx][0]) if i in indices ], dim=0, ), torch.stack( [ k for i, k in enumerate(past_key_value[self.layer_idx][1]) if i in indices ], dim=0, ), ) logger.info(f"Hidden states {hidden_states.shape}") logger.info(f"k {_past_key_value[self.layer_idx][0].shape}") logger.info(f"v {_past_key_value[self.layer_idx][1].shape}") past_key_value = _past_key_value residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, attn_weights, past_key_value = self.cross_attn( hidden_states=hidden_states, attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, past_key_value=past_key_value, cache_position=cache_position, ) hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) if full_text_row_masked_out_mask is not None: hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states if is_mixed: out_hidden_states[indices] = hidden_states hidden_states = out_hidden_states from loguru import logger logger.info(f"After Hidden states {hidden_states.shape}") return hidden_states class MllamaTextSelfAttention(nn.Module): def __init__(self, *, prefix, config, weights, layer_idx): super().__init__() self.config = config self.num_heads = config.num_attention_heads self.dropout = config.dropout self.hidden_size = config.hidden_size self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = ( self.num_key_value_heads // weights.process_group.size() ) self.layer_idx = layer_idx self.qkv_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], dim=0, weights=weights, bias=False, ) self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, bias=False, ) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, past_key_value=None, cache_position=None, **kwargs, ): bsz, q_len, _ = hidden_states.size() qkv = self.qkv_proj(hidden_states) query_states, key_states, value_states = qkv.split( [ self.head_dim * self.num_heads, self.head_dim * self.num_key_value_heads, self.head_dim * self.num_key_value_heads, ], dim=2, ) query_states = query_states.view( bsz, q_len, self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask if attention_mask is not None: causal_mask = causal_mask[:, :, :, 
: key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, # TODO # attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText class MllamaTextRMSNorm(nn.Module): def __init__(self, weight, eps): super().__init__() self.weight = weight self.variance_epsilon = eps @classmethod def load(cls, *, prefix, weights, eps): weight = nn.Parameter( weights.get_tensor(f"{prefix}.weight"), requires_grad=False ) return cls(weight=weight, eps=eps) def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT class MllamaSelfAttentionDecoderLayer(nn.Module): def __init__(self, *, prefix, config, weights, layer_idx): super().__init__() self.hidden_size = config.hidden_size self.self_attn = MllamaTextSelfAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_idx=layer_idx, ) self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.input_layernorm = MllamaTextRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps ) self.post_attention_layernorm = MllamaTextRMSNorm.load( prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value=None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[ Tuple[torch.Tensor, torch.Tensor] ] = None, # will become mandatory in v4.45 **kwargs, ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] ]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states 
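        # Editor's note: this is the standard pre-norm residual block,
        #   x = x + SelfAttn(RMSNorm(x));  x = x + MLP(RMSNorm(x)),
        # whereas the cross-attention decoder layers above additionally scale each branch
        # by a learned tanh gate before adding it to the residual.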
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class MllamaRotaryEmbedding(nn.Module):
    def __init__(self, *, config, weights):
        super().__init__()
        device = weights.device
        self.rope_type = config.rope_scaling["rope_type"]
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        # extra kwargs forwarded to the rope init fn on dynamic updates
        self.rope_kwargs = {}
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        inv_freq = inv_freq.to(device=device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()

        # Advanced RoPE types (e.g.
yarn) apply a post-processing scaling factor, equivalent to scaling attention cos = cos * self.attention_scaling sin = sin * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class MllamaTextModel(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.config = config self.embed_tokens = TensorParallelEmbedding( prefix=f"{prefix}.embed_tokens", weights=weights ) self.cross_attention_layers = config.cross_attention_layers self.layers = [] for layer_idx in range(config.num_hidden_layers): if layer_idx in self.cross_attention_layers: self.layers.append( MllamaCrossAttentionDecoderLayer( prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights, layer_idx=layer_idx, ) ) else: self.layers.append( MllamaSelfAttentionDecoderLayer( prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights, layer_idx=layer_idx, ) ) self.norm = MllamaTextRMSNorm.load( prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps, ) self.rotary_emb = MllamaRotaryEmbedding(config=config, weights=weights) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cross_attention_states: Optional[torch.FloatTensor] = None, cross_attention_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[ Tuple[torch.Tensor, torch.Tensor] ] = None, past_key_values=None, inputs_embeds: Optional[torch.FloatTensor] = None, cache_position: Optional[torch.LongTensor] = None, ): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds if cache_position is None: past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, ) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers for idx, decoder_layer in enumerate(self.layers): # if ( # idx in self.cross_attention_layers # and cross_attention_states is None # and ( # past_key_values is None # or ( # past_key_values is not None # and any(past_key_values.get_seq_length(idx) == 0 # ) # ) # ): # continue layer_outputs = decoder_layer( hidden_states, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, attention_mask=causal_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, position_ids=position_ids, past_key_value=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = layer_outputs hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, ) def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask 
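            # Editor's note: for flash-attention the 2D padding mask is returned unchanged
            # (or dropped entirely when nothing is padded); the 4D additive causal mask below
            # is only materialized for the SDPA code path.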
return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line # self.config._attn_implementation == "sdpa" and # if self.config._attn_implementation == "sdpa" and not using_static_cache: if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
# Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended( causal_mask, min_dtype ) return causal_mask class MllamaForCausalLM(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() self.vocab_size = config.vocab_size self.model = MllamaTextModel( prefix=f"{prefix}.model", config=config, weights=weights ) self.lm_head = SpeculativeHead.load( prefix=f"{prefix}.lm_head", config=config, weights=weights, ) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cross_attention_states: Optional[torch.LongTensor] = None, cross_attention_mask: Optional[torch.LongTensor] = None, full_text_row_masked_out_mask: Optional[ Tuple[torch.Tensor, torch.Tensor] ] = None, past_key_values=None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ): # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, cross_attention_states=cross_attention_states, attention_mask=attention_mask, position_ids=position_ids, cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, cache_position=cache_position, ) hidden_states = outputs.last_hidden_state # if lm_head_indices is not None: # hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.lm_head(hidden_states) return ( CausalLMOutputWithPast( logits=logits, past_key_values=outputs.past_key_values, ), speculative_logits, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, position_ids=None, num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] elif ( input_ids.shape[1] != cache_position.shape[0] ): # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: # The clone here is for the same reason as for `position_ids`. 
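            # Editor's note (illustrative): with a populated cache, only the not-yet-processed
            # tail of `input_ids` is forwarded; e.g. if 7 tokens are cached and one new token
            # was appended, cache_position is [7] and input_ids is sliced to its last column.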
model_inputs = { "input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None, } if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: if model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape device = model_inputs["inputs_embeds"].device else: batch_size, sequence_length = model_inputs["input_ids"].shape device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_length(), dtype=dtype, device=device, min_dtype=min_dtype, cache_position=cache_position, batch_size=batch_size, ) if num_logits_to_keep is not None: model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, "attention_mask": attention_mask, } ) return model_inputs class MllamaForConditionalGeneration(nn.Module): def __init__(self, prefix, config, weights): super().__init__() config.vision_config.quantize = None config.vision_config.speculator = config.speculator config.text_config.quantize = config.quantize config.text_config.speculator = config.speculator config.text_config._attn_implementation = "sdpa" self.hidden_size = config.text_config.hidden_size self.vision_model = MllamaVisionModel( prefix="vision_model", config=config.vision_config, weights=weights ) self.language_model = MllamaForCausalLM( prefix="language_model", config=config.text_config, weights=weights ) self.multi_modal_projector = FastLinear.load( prefix="multi_modal_projector", config=config, weights=weights, bias=True ) self.config = config self.dtype = weights.dtype self.device = weights.device def forward( self, input_ids: torch.LongTensor = None, pixel_values: Optional[torch.FloatTensor] = None, aspect_ratio_mask: Optional[List[List[int]]] = None, aspect_ratio_ids: Optional[torch.Tensor] = None, attention_mask: Optional[List[List[List[int]]]] = None, cross_attention_mask: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, image_hidden_states=None, image_attention_mask=None, ): if past_key_values is None: past_key_values = DynamicCache( num_hidden_layers=self.config.text_config.num_hidden_layers ) elif isinstance(past_key_values, list): past_key_values = DynamicCache.from_legacy_cache(past_key_values) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if pixel_values is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) if pixel_values is not None and cross_attention_states is not None: raise ValueError( "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" ) if pixel_values is not None: if aspect_ratio_ids is None: raise ValueError( "`aspect_ratio_ids` must be provided if `pixel_values` is provided" ) # get vision tokens from vision model vision_states = 
self.vision_model( pixel_values, aspect_ratio_ids, aspect_ratio_mask ) cross_attention_states = self.multi_modal_projector(vision_states).reshape( -1, vision_states.shape[-2], self.hidden_size ) if cross_attention_mask is not None: cross_attention_mask, full_text_row_masked_out_mask = ( _prepare_cross_attention_mask( cross_attention_mask, num_vision_tokens=self.vision_model.num_patches, dtype=self.dtype, ) ) else: full_text_row_masked_out_mask = None if cross_attention_mask is not None and cache_position is not None: cross_attention_mask = cross_attention_mask[:, :, cache_position] full_text_row_masked_out_mask = full_text_row_masked_out_mask[ :, :, cache_position ] outputs = self.language_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, cache_position=cache_position, num_logits_to_keep=num_logits_to_keep, ) return outputs
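

# ----------------------------------------------------------------------------
# Editor's note: the guarded block below is an illustrative sanity-check sketch,
# not part of the model. It only exercises the weight-free mask helpers defined
# above with toy tensors to show their expected shapes; all sizes are made up.
if __name__ == "__main__":
    # repeat_kv: (batch, kv_heads, seq, head_dim) -> (batch, kv_heads * n_rep, seq, head_dim)
    kv = torch.randn(1, 2, 5, 8)
    assert repeat_kv(kv, 4).shape == (1, 8, 5, 8)

    # _prepare_cross_attention_mask expects (batch, text_len, max_num_images, max_num_tiles)
    # and expands each tile to `num_vision_tokens` columns.
    cross_mask = torch.ones(1, 6, 1, 4)
    cross_mask[:, :2] = 0  # the first two text tokens do not attend to the image
    mask, full_row = _prepare_cross_attention_mask(
        cross_mask, num_vision_tokens=3, dtype=torch.float32
    )
    assert mask.shape == (1, 1, 6, 12)
    assert full_row.shape == (1, 1, 6, 1)
    print("mask helper sanity checks passed")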