# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""

from typing import Optional, Tuple, List

import torch
import torch.utils.checkpoint
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex
else:
    import flash_attn_2_cuda

from transformers.activations import ACT2FN
import torch.nn.functional as F

from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    FastLinear,
)
from text_generation_server.layers.attention import (
    Seqlen,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
)


def _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: torch.Tensor,
    num_patches: int,
    target_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Expand aspect ratio mask to target_length
    batch_size, max_num_tiles = aspect_ratio_mask.shape
    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
    attention_mask = attention_mask.repeat(1, 1, target_length, 1)

    # Mask padding patches
    pad_patches = target_length - num_patches
    attention_mask[:, :, -pad_patches:] = 0

    # Invert the mask (0 -> 1, 1 -> 0)
    attention_mask = 1 - attention_mask

    # Reshape to 2D and create 4D attention mask
    # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
    attention_mask = attention_mask.reshape(
        batch_size, max_num_tiles * target_length, 1
    )
    attention_mask = (
        attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
    )
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask

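
# Illustrative shape walk-through (not executed): with max_num_tiles=4 and
# num_patches=1025 padded up to target_length=1032, the (batch, 4) aspect-ratio mask
# above becomes an additive bias of shape (batch, 1, 4 * 1032, 4 * 1032), built from
# the outer product of the inverted per-patch mask and scaled by torch.finfo(dtype).min.
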
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to
            account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(
            target_length, device=device
        ) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = (
                causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            )
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)

    return causal_mask


def _prepare_cross_attention_mask(
    cross_attention_mask: torch.Tensor,
    num_vision_tokens: int,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # reshape so it can be used by attn module
    batch_size, text_total_length, *_ = cross_attention_mask.shape
    cross_attention_mask = cross_attention_mask.repeat_interleave(
        num_vision_tokens, dim=3
    )
    cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
    cross_attention_mask = cross_attention_mask.unsqueeze(1)

    # invert the mask
    inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
    cross_attention_mask = inverted_cross_attn_mask.masked_fill(
        inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
    )

    # apply full-row bias, which returns a 4D tensor of shape [B, H, S1, 1] whose value is 0 if a full row in the
    # cross-attn mask's last dimension is entirely negative infinity (fully masked), otherwise it's 1
    negative_inf_value = torch.finfo(dtype).min
    full_text_row_masked_out_mask = (
        (cross_attention_mask != negative_inf_value)
        .any(dim=-1)
        .type_as(cross_attention_mask)[..., None]
    )
    cross_attention_mask *= full_text_row_masked_out_mask

    return cross_attention_mask, full_text_row_masked_out_mask


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class MllamaVisionSdpaAttention(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()

        self.embed_dim = config.hidden_size
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_heads = config.attention_heads // weights.process_group.size()

        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

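    # The vision encoder attends bidirectionally over the flattened tile/patch tokens:
    # queries, keys and values are all derived from `hidden_state`, the optional
    # additive `attention_mask` comes from `_prepare_aspect_ratio_attention_mask`,
    # and no causal masking is applied.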
    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_state)
        query, key, value = qkv.split(
            [
                self.head_dim * self.num_heads,
                self.head_dim * self.num_heads,
                self.head_dim * self.num_heads,
            ],
            dim=2,
        )

        batch_size, q_seq_len, _ = query.shape
        _, kv_seq_len, _ = key.shape
        query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
        key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
        value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_seq_len, -1)

        output = self.o_proj(attn_output)
        return output


class MllamaVisionEncoderLayer(nn.Module):
    def __init__(self, *, prefix, config, weights, is_gated: bool):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = MllamaVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

        self.input_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
        )
        self.post_attention_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
        )

        # Gate parameters are only loaded for gated layers (the global encoder).
        if is_gated:
            self.gate_attn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
            )
            self.gate_ffn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
            )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
        hidden_state = residual + gate_attn * hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
        hidden_state = residual + gate_ffn * hidden_state

        return hidden_state


class MllamaVisionEncoder(nn.Module):
    def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
        super().__init__()
        self.config = config
        self.layers = [
            MllamaVisionEncoderLayer(
                prefix=f"{prefix}.layers.{i}",
                config=config,
                weights=weights,
                is_gated=is_gated,
            )
            for i in range(num_layers)
        ]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        encoder_states = [hidden_states]
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
            )
            hidden_states = layer_outputs
            encoder_states.append(hidden_states)

        return hidden_states, encoder_states


class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id

        self.embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.embedding", weights=weights
        )
weights.get_tensor(f"{prefix}.gate"), requires_grad=False ) def forward( self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor ) -> torch.Tensor: embeddings = self.embedding(aspect_ratio_ids) embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) # Always gated. embeddings = embeddings * self.gate.tanh() hidden_state = hidden_state + embeddings return hidden_state class MllamaPrecomputedPositionEmbedding(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() self.max_num_tiles = config.max_num_tiles self.max_aspect_ratio_id = config.max_aspect_ratio_id self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 self.hidden_size = config.hidden_size self.scale = config.hidden_size**-0.5 self.gate = nn.Parameter( weights.get_tensor(f"{prefix}.gate"), requires_grad=False ) # position embedding embedding = nn.Parameter( weights.get_tensor(f"{prefix}.embedding"), requires_grad=False ) self.gated_position_embedding = (1 - self.gate.tanh()) * embedding self.tile_embedding = TensorParallelEmbedding( prefix=f"{prefix}.tile_embedding", weights=weights ) def forward( self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor ) -> torch.Tensor: # position embeddings hidden_state = hidden_state + self.gated_position_embedding.view( 1, 1, self.num_patches, self.hidden_size ) # precomputed tile position embeddings tile_position_embedding = self.tile_embedding(aspect_ratio_ids) batch_size = hidden_state.shape[0] tile_position_embedding = tile_position_embedding.reshape( batch_size, self.max_num_tiles, self.num_patches, self.hidden_size ) gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding hidden_state = hidden_state + gated_tile_position_embedding return hidden_state class MllamaVisionModel(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() self.image_size = config.image_size self.patch_size = config.patch_size self.max_num_tiles = config.max_num_tiles self.hidden_size = config.hidden_size self.num_channels = config.num_channels self.intermediate_layers_indices = config.intermediate_layers_indices self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 self.scale = config.hidden_size**-0.5 self.dtype = weights.dtype self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", bias=False, ) self.patch_embedding.weight = nn.Parameter( weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False ) self.class_embedding = nn.Parameter( weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False ) self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( prefix=f"{prefix}.gated_positional_embedding", config=config, weights=weights, ) self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( prefix=f"{prefix}.pre_tile_positional_embedding", config=config, weights=weights, ) self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( prefix=f"{prefix}.post_tile_positional_embedding", config=config, weights=weights, ) ## layer norms self.layernorm_pre = nn.LayerNorm.load( prefix=f"{prefix}.layernorm_pre", weights=weights, # torch default eps=1e-05, ) self.layernorm_post = nn.LayerNorm.load( prefix=f"{prefix}.layernorm_post", weights=weights, # torch default eps=1e-05, ) ## encoders self.transformer = MllamaVisionEncoder( prefix=f"{prefix}.transformer", config=config, weights=weights, is_gated=False, 
        self.transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.transformer",
            config=config,
            weights=weights,
            is_gated=False,
            num_layers=config.num_hidden_layers,
        )
        self.global_transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.global_transformer",
            config=config,
            weights=weights,
            is_gated=True,
            num_layers=config.num_global_layers,
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        (
            batch_size,
            num_concurrent_media,
            num_tiles,
            num_channels,
            height,
            width,
        ) = pixel_values.shape

        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1
        )

        # patch embedding
        patch_embeds = self.patch_embedding(pixel_values)
        hidden_state = patch_embeds.flatten(2).transpose(1, 2)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, -1, dim
        )
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim
        )
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, num_patches, dim
        )
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0,
            0,
            0,
            num_padding_patches,
        )  # (pad_left, pad_right for the last dim; pad_front, pad_back for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(
                batch_size * num_concurrent_media, -1
            )
            attention_mask = _prepare_aspect_ratio_attention_mask(
                aspect_ratio_mask=attention_mask,
                num_patches=self.num_patches,
                target_length=hidden_state.shape[2],
                dtype=self.dtype,
            )

        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
        hidden_state, all_intermediate_hidden_states = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        intermediate_hidden_states = [
            hidden_state
            for idx, hidden_state in enumerate(all_intermediate_hidden_states)
            if idx in self.intermediate_layers_indices
        ]
        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches),
            dim,
        )
        hidden_state, _ = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = hidden_state[:, :, :slice_index]

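        # The final vision features concatenate the last hidden state with the selected
        # intermediate layer outputs along the feature dimension, so the last dimension
        # becomes hidden_size * (1 + len(intermediate_layers_indices)).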
        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            -1,
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
        return hidden_state


class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, *, prefix, config, weights, layer_idx):
        super().__init__()
        self.config = config
        self.num_heads = self.config.num_attention_heads
        self.num_key_value_heads = self.config.num_key_value_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_size = config.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.layer_idx = layer_idx

        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            self.num_key_value_heads // weights.process_group.size()
        )

        self.q_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.k_proj",
            weights=weights,
            bias=False,
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.v_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        self.q_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.k_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.softmax_scale = self.head_size**-0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[torch.Tensor] = None,
        # past_key_value=None,
        # attention_mask: Optional[torch.Tensor] = None,
        # cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        # hidden_states = hidden_states.unsqueeze(0)
        # bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(-1, self.num_heads, self.head_size)
        query_states = self.q_norm(query_states)

        (
            cross_attention_states,
            cu_seqlen_q,
            cu_seqlen_k,
            max_q,
            max_k,
            indices,
        ) = cross_attention_states

        key_states = self.k_proj(cross_attention_states)
        value_states = self.v_proj(cross_attention_states)
        key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
        value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
        key_states = self.k_norm(key_states)

        # key_states = key_states.repeat(1, self.num_key_value_groups, 1)
        # value_states = value_states.repeat(1, self.num_key_value_groups, 1)

        causal = False
        # logger.info(
        #     f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
        # )
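        # Cross-attention runs as variable-length ("ragged") attention over the
        # concatenated image tokens: `cu_seqlen_q` / `cu_seqlen_k` delimit the
        # per-image query and key spans, and `causal` stays False because there is no
        # autoregressive ordering between text queries and vision keys.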
        if SYSTEM == "ipex":
            attn_output = torch.empty_like(query_states)
            ipex.llm.functional.varlen_attention(
                (
                    query_states.contiguous()
                    if query_states.device.type == "xpu"
                    else query_states
                ),
                (
                    key_states.contiguous()
                    if key_states.device.type == "xpu"
                    else key_states
                ),
                (
                    value_states.contiguous()
                    if value_states.device.type == "xpu"
                    else value_states
                ),
                attn_output,
                cu_seqlen_q,
                cu_seqlen_k,
                max_q,
                max_k,
                0.0,
                self.softmax_scale,
                False,
                causal,
                False,
                None,
            )
        else:
            attn_output = flash_attn_2_cuda.varlen_fwd(
                query_states,
                key_states,
                value_states,
                None,
                cu_seqlen_q,
                cu_seqlen_k,
                None,
                None,
                None,  # block_tables
                None,
                max_q,
                max_k,
                0.0,
                self.softmax_scale,
                False,
                causal,  # Causal
                -1,  # window_size_left,
                -1,
                0.0,  # softcap
                False,
                None,
            )[0]

        attn_output = self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size)
        )

        return attn_output


# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()

        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        shape = x.shape
        gate_up_states = self.gate_up_proj(x)
        gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
        result = self.down_proj(
            self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
        )
        return result


class FlashLlamaCrossLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention and feedforward."""

    def __init__(self, *, prefix, config, weights, index) -> None:
        layer_idx = index
        super().__init__()
        self.cross_attn = MllamaTextCrossAttention(
            prefix=f"{prefix}.cross_attn",
            config=config,
            weights=weights,
            layer_idx=layer_idx,
        )

        self.input_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.cross_attn_attn_gate = torch.nn.Parameter(
            weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
        )

        self.mlp = MllamaTextMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )
        self.post_attention_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.cross_attn_mlp_gate = torch.nn.Parameter(
            weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
        )
        self.layer_idx = layer_idx

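    # This layer mirrors the call signature of the regular FlashLlama decoder layers so
    # the text model can interleave it with its self-attention layers; the rotary and
    # KV-cache arguments are accepted but unused here, since this block only
    # cross-attends to the precomputed vision states.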
    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        seqlen,
        max_s,
        adapter_data,
        cross_attention_states,  # [ IB, ...]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if cross_attention_states is None:
            return hidden_states, residual
        if residual is not None:
            hidden_states += residual

        indices = cross_attention_states[-1]
        out_hidden_states = hidden_states[:]
        if len(indices) > 0:
            assert max(indices) < hidden_states.shape[0]
            hidden_states = hidden_states[indices]
            residual = hidden_states

            hidden_states = self.input_layernorm(hidden_states)
            hidden_states = self.cross_attn(
                hidden_states=hidden_states,
                # attention_mask=cross_attention_mask,
                cross_attention_states=cross_attention_states,
            )
            hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states

            residual = hidden_states
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states

            out_hidden_states[indices] = hidden_states

        hidden_states = out_hidden_states

        return hidden_states, None


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class MllamaTextRMSNorm(nn.Module):
    def __init__(self, weight, eps):
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    @classmethod
    def load(cls, *, prefix, weights, eps):
        weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.weight"), requires_grad=False
        )
        return cls(weight=weight, eps=eps)

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class MllamaForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        config.text_config._attn_implementation = "sdpa"
        self.hidden_size = config.text_config.hidden_size
        self.vision_model = MllamaVisionModel(
            prefix="vision_model", config=config.vision_config, weights=weights
        )
        self.multi_modal_projector = FastLinear.load(
            prefix="multi_modal_projector", config=config, weights=weights, bias=True
        )
        self.text_model = FlashLlamaForCausalLM(
            prefix="language_model", config=config.text_config, weights=weights
        )
        self.config = config
        self.dtype = weights.dtype
        self.device = weights.device

    def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
        if aspect_ratio_ids is None:
            raise ValueError(
                "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
            )
        # logger.info(f"PIxel values {pixel_values.shape}")
        batch_size = pixel_values.shape[0]
        vision_states = self.vision_model(
            pixel_values, aspect_ratio_ids, aspect_ratio_mask
        )
        cross_attention_states = self.multi_modal_projector(vision_states).reshape(
            -1, vision_states.shape[-2], self.hidden_size
        )
        _, _, h = cross_attention_states.shape
        cross_attention_states = cross_attention_states.view(batch_size, -1, h)
        # logger.info(f"cross {cross_attention_states.shape}")
        return cross_attention_states

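    # `forward` packages the projected vision states together with the ragged-attention
    # metadata (`cu_seqlen_q`, `cu_seqlen_k`, `max_q`, `max_k`, `indices`) expected by
    # MllamaTextCrossAttention, then delegates generation to the Llama text model.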
    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor],
        adapter_data: Optional[torch.Tensor] = None,
        # XXX: Putting these as optional so that the cuda warmup calls can go through.
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        if cross_attention_states is not None:
            seqlen_q = len(image_indices)
            n_images = cross_attention_states.shape[0]
            seqlen_k = cross_attention_states.shape[1]
            device = cross_attention_states.device
            if cu_seqlen_prefill is not None:
                offset = 0
                cu_q = []
                indices = []
                for index in image_indices:
                    cu_q.append(offset)
                    length = seqlen.input_lengths[index].item()
                    assert index < seqlen.cu_seqlen_q.shape[0]
                    input_ids_offset = seqlen.cu_seqlen_q[index]
                    indices.extend(range(input_ids_offset, input_ids_offset + length))
                    offset += length
                cu_q.append(offset)
                cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)

                assert max(indices) < input_ids.shape[0]

                cu_seqlen_k = (
                    torch.arange(
                        n_images + 1,
                        device=device,
                        dtype=torch.int32,
                    )
                    * seqlen_k
                )
                max_q = cu_seqlen_q[-1].item()
                max_k = seqlen_k
            else:
                cu_seqlen_q = torch.arange(
                    seqlen_q + 1, device=device, dtype=torch.int32
                )
                seqlen_k = cross_attention_states.shape[1]
                n_images = cross_attention_states.shape[0]
                cu_seqlen_k = (
                    torch.arange(
                        n_images + 1,
                        device=device,
                        dtype=torch.int32,
                    )
                    * seqlen_k
                )
                max_q = seqlen_q
                max_k = seqlen_k
                indices = image_indices[:]

            cross_attention_states = (
                cross_attention_states,
                cu_seqlen_q,
                cu_seqlen_k,
                max_q,
                max_k,
                indices,
            )

        outputs = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            prefill_cache_indices=prefill_cache_indices,
            lm_head_indices=lm_head_indices,
            adapter_data=adapter_data,
            cross_attention_states=cross_attention_states,
        )

        return outputs