# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Mllama model. """

from typing import Optional, Tuple
import math

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN

from text_generation_server.layers import (
    FastLinear,
    SpeculativeHead,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
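
# MllamaVisionSdpaAttention fuses the q/k/v projections into a single
# column-parallel linear, splits the result back into per-head query/key/value
# tensors, and runs torch's scaled_dot_product_attention over the patch
# sequence of each tile.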


class MllamaVisionSdpaAttention(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()

        self.embed_dim = config.hidden_size
        self.num_heads = config.attention_heads
        self.head_dim = config.hidden_size // config.attention_heads

        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_state)
        # Split the fused projection back into query, key and value along the
        # feature dimension.
        query, key, value = qkv.split(
            [
                self.head_dim * self.num_heads,
                self.head_dim * self.num_heads,
                self.head_dim * self.num_heads,
            ],
            dim=-1,
        )

        batch_size, q_seq_len, _ = query.shape
        _, kv_seq_len, _ = key.shape

        query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
        key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
        value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_seq_len, -1)

        output = self.o_proj(attn_output)
        return output
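
# MllamaVisionEncoderLayer is a standard pre-norm transformer block. When
# `is_gated=True` (used by the global transformer below) the attention and MLP
# residual branches are additionally scaled by learned tanh gates
# (`gate_attn` / `gate_ffn`).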


class MllamaVisionEncoderLayer(nn.Module):
    def __init__(self, *, prefix, config, weights, is_gated: bool):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = MllamaVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

        self.input_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
        )
        self.post_attention_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
        )

        # Only the gated (global) layers carry gate parameters.
        if is_gated:
            self.gate_attn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
            )
            self.gate_ffn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
            )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
        hidden_state = residual + gate_attn * hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
        hidden_state = residual + gate_ffn * hidden_state

        return hidden_state
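
# MllamaVisionEncoder stacks the encoder layers. It also records the hidden
# state that enters every layer so that MllamaVisionModel can later pick the
# activations listed in `config.intermediate_layers_indices`.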


class MllamaVisionEncoder(nn.Module):
    def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
        super().__init__()
        self.config = config
        self.layers = [
            MllamaVisionEncoderLayer(
                prefix=f"{prefix}.layers.{i}",
                config=config,
                weights=weights,
                is_gated=is_gated,
            )
            for i in range(num_layers)
        ]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Keep the input of every layer around: the vision model selects a
        # subset of these as "intermediate" features.
        all_hidden_states = []
        for encoder_layer in self.layers:
            all_hidden_states.append(hidden_states)
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
            )
        return hidden_states, all_hidden_states
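
# MllamaPrecomputedAspectRatioEmbedding adds a per-tile embedding selected by
# the aspect-ratio id of each image. Its contribution is always scaled by a
# learned tanh gate.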


class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id

        self.embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.embedding", weights=weights
        )
        self.gate = nn.Parameter(
            weights.get_tensor(f"{prefix}.gate"), requires_grad=False
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        embeddings = self.embedding(aspect_ratio_ids)
        embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)

        # Always gated.
        embeddings = embeddings * self.gate.tanh()

        hidden_state = hidden_state + embeddings
        return hidden_state


class MllamaPrecomputedPositionEmbedding(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        # The gate is a trained checkpoint tensor, like the other gates above.
        self.gate = nn.Parameter(
            weights.get_tensor(f"{prefix}.gate"), requires_grad=False
        )

        # position embedding
        self.embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
        )
        self.tile_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.tile_embedding", weights=weights
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        # position embeddings
        gated_position_embedding = (1 - self.gate.tanh()) * self.embedding
        hidden_state = hidden_state + gated_position_embedding.view(
            1, 1, self.num_patches, self.hidden_size
        )

        # precomputed tile position embeddings
        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
        batch_size = hidden_state.shape[0]
        tile_position_embedding = tile_position_embedding.reshape(
            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
        )
        gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
        hidden_state = hidden_state + gated_tile_position_embedding

        return hidden_state
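
# MllamaVisionModel: pixel values are patch-embedded with a strided Conv2d,
# enriched with pre-tile, class-token and gated positional embeddings, run
# through the local transformer, then through the gated global transformer,
# and finally concatenated with a set of intermediate layer activations along
# the feature dimension.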


class MllamaVisionModel(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.in_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5
        # `forward` casts pixel values to the weights' dtype/device.
        self.dtype = weights.dtype
        self.device = weights.device

        self.patch_embedding = nn.Conv2d(
            in_channels=config.in_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )

        self.class_embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
        )

        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            prefix=f"{prefix}.gated_positional_embedding",
            config=config,
            weights=weights,
        )

        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.pre_tile_positional_embedding",
            config=config,
            weights=weights,
        )
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.post_tile_positional_embedding",
            config=config,
            weights=weights,
        )

        ## layer norms
        self.layernorm_pre = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_pre",
            weights=weights,
            # torch default
            eps=1e-05,
        )
        self.layernorm_post = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_post",
            weights=weights,
            # torch default
            eps=1e-05,
        )

        ## encoders
        self.transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.transformer",
            config=config,
            weights=weights,
            is_gated=False,
            num_layers=config.num_hidden_layers,
        )
        self.global_transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.global_transformer",
            config=config,
            weights=weights,
            is_gated=True,
            num_layers=config.num_global_layers,
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
            pixel_values.shape
        )

        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1
        )

        # patch embedding
        patch_embeds = self.patch_embedding(
            pixel_values.to(self.dtype).to(self.device)
        )
        hidden_state = patch_embeds.flatten(2).transpose(1, 2)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, -1, dim
        )
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim
        )
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, num_patches, dim
        )
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad (sequence length padded to a multiple of 8)
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for the pad function:
        # (pad_left, pad_right) for the last dim, then (pad_left, pad_right) for dim -2.
        padding = (0, 0, 0, num_padding_patches)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(
                batch_size * num_concurrent_media, -1
            )
            attention_mask = _prepare_aspect_ratio_attention_mask(
                aspect_ratio_mask=attention_mask,
                num_patches=self.num_patches,
                target_length=hidden_state.shape[2],
                dtype=self.dtype,
            )

        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
        hidden_state, all_intermediate_hidden_states = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        intermediate_hidden_states = [
            hidden_state
            for idx, hidden_state in enumerate(all_intermediate_hidden_states)
            if idx in self.intermediate_layers_indices
        ]
        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches),
            dim,
        )
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )[0]
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            -1,
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
        return hidden_state
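
# MllamaTextCrossAttention: queries come from the text hidden states, while
# keys and values are projected from the vision `cross_attention_states`.
# Once computed for an image, the key/value states are cached and reused on
# subsequent decode steps.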


class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.num_heads = self.config.num_attention_heads
        self.num_key_value_heads = self.config.num_key_value_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # The query projection acts on the text hidden states while the
        # key/value projections act on the vision cross-attention states, so
        # they cannot be fused into a single matmul.
        self.q_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.k_proj",
            weights=weights,
            bias=False,
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.v_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        self.q_norm = FastRMSNorm.load(
            prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.k_norm = FastRMSNorm.load(
            prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[torch.Tensor] = None,
        past_key_value=None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        use_cache: bool = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        # FastRMSNorm returns (normed_states, residual); only the first is needed.
        query_states, _ = self.q_norm(query_states)

        if cross_attention_states is not None:
            key_states = self.k_proj(cross_attention_states)
            value_states = self.v_proj(cross_attention_states)
            key_states = key_states.view(
                bsz, -1, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            value_states = value_states.view(
                bsz, -1, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            key_states, _ = self.k_norm(key_states)
            if past_key_value is not None:
                # if we have a new image + new tokens, we only computed key_states on that new image
                # we still update the cross key states, past_image, new_image. And use it!
                key_states, value_states = past_key_value.update(
                    key_states,
                    value_states,
                    self.layer_idx,
                    {"cache_position": cache_position},
                )
        elif cache_position[0] != 0:
            key_states, value_states = (
                past_key_value.key_cache[self.layer_idx],
                past_key_value.value_cache[self.layer_idx],
            )
        else:
            raise ValueError(
                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
            )

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
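
# MllamaTextMLP fuses the gate and up projections into one column-parallel
# linear; the two halves are split again in `forward` before the
# activation-gated product is fed to the down projection.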


# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate_up_states = self.gate_up_proj(x)
        # Split the fused projection into its gate and up halves without
        # flattening the leading (batch, seq) dimensions.
        gate_up_states = gate_up_states.view(
            *gate_up_states.shape[:-1], 2, self.intermediate_size
        )
        return self.down_proj(
            self.act_fn(gate_up_states[..., 0, :]) * gate_up_states[..., 1, :]
        )
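
# MllamaCrossAttentionDecoderLayer wraps the text cross-attention block. Both
# the attention and MLP residual branches are scaled by learned tanh gates,
# and `full_text_row_masked_out_mask` zeroes the MLP output for text rows that
# have no image to attend to.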


class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention and feedforward."""

    def __init__(self, *, prefix, config, weights) -> None:
        super().__init__()
        self.cross_attn = MllamaTextCrossAttention(
            prefix=f"{prefix}.cross_attn", config=config, weights=weights
        )

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.cross_attn_attn_gate = torch.nn.Parameter(
            weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
        )

        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.cross_attn_mlp_gate = torch.nn.Parameter(
            weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        attention_mask: torch.Tensor,
        full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
        past_key_value=None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states
        # FastRMSNorm returns (normed_states, residual); only the first is needed.
        hidden_states, _ = self.input_layernorm(hidden_states)

        hidden_states, attn_weights, past_key_value = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            cross_attention_states=cross_attention_states,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states

        residual = hidden_states
        hidden_states, _ = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if full_text_row_masked_out_mask is not None:
            hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states  # type: ignore
        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        if use_cache:
            outputs += (past_key_value,)

        return outputs
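
# MllamaTextSelfAttention is a standard grouped-query self-attention block:
# the fused q/k/v projection is split per head, rotary embeddings are applied,
# the key/value heads are repeated to match the query heads, and SDPA performs
# the (optionally causal) attention.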


class MllamaTextSelfAttention(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
        output_attentions: bool = False,
        use_cache: bool = False,
        past_key_value=None,
        cache_position=None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        # Split the fused q/k/v projection back into query, key and value states.
        qkv = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [
                self.head_dim * self.num_heads,
                self.head_dim * self.num_key_value_heads,
                self.head_dim * self.num_key_value_heads,
            ],
            dim=-1,
        )

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value


# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT
class MllamaSelfAttentionDecoderLayer(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MllamaTextSelfAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value=None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
                into the model
        """
        residual = hidden_states
        # FastRMSNorm returns (normed_states, residual); only the first is needed.
        hidden_states, _ = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states, _ = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
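
# MllamaTextModel interleaves self-attention decoder layers with
# cross-attention layers at the indices listed in
# `config.cross_attention_layers`. Cross-attention layers are skipped when no
# vision states are provided and nothing has been cached for them yet.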


class MllamaTextModel(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        self.cross_attention_layers = config.cross_attention_layers

        self.layers = []
        for layer_idx in range(config.num_hidden_layers):
            if layer_idx in self.cross_attention_layers:
                self.layers.append(
                    MllamaCrossAttentionDecoderLayer(
                        prefix=f"{prefix}.layers.{layer_idx}",
                        config=config,
                        weights=weights,
                    )
                )
            else:
                self.layers.append(
                    MllamaSelfAttentionDecoderLayer(
                        prefix=f"{prefix}.layers.{layer_idx}",
                        config=config,
                        weights=weights,
                    )
                )

        # TODO Should we use this slow norm?
        # self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        # TODO Anything specific?
        head_size = config.hidden_size // config.num_attention_heads
        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=head_size,
            base=config.rope_theta,
            device=weights.device,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.FloatTensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        full_text_row_masked_out_mask: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # The 4D causal mask construction is currently disabled; the provided
        # attention mask is passed through to the decoder layers as-is.
        # causal_mask = self._update_causal_mask(
        #     attention_mask,
        #     inputs_embeds,
        #     cache_position,
        #     past_key_values,
        # )
        causal_mask = attention_mask

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            if (
                idx in self.cross_attention_layers
                and cross_attention_states is None
                and (
                    past_key_values is None
                    or (
                        past_key_values is not None
                        and past_key_values.get_seq_length(idx) == 0
                    )
                )
            ):
                continue

            layer_outputs = decoder_layer(
                hidden_states,
                cross_attention_states=cross_attention_states,
                cross_attention_mask=cross_attention_mask,
                attention_mask=causal_mask,
                full_text_row_masked_out_mask=full_text_row_masked_out_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]

        hidden_states, _ = self.norm(hidden_states)

        return hidden_states
# def _update_causal_mask(
# self,
# attention_mask: torch.Tensor,
# input_tensor: torch.Tensor,
# cache_position: torch.Tensor,
# past_key_values,
# ):
# if self.config._attn_implementation == "flash_attention_2":
# if attention_mask is not None and 0.0 in attention_mask:
# return attention_mask
# return None
# # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# # to infer the attention mask.
# past_seen_tokens = (
# past_key_values.get_seq_length() if past_key_values is not None else 0
# )
# using_static_cache = isinstance(past_key_values, StaticCache)
# # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
# # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line
# # self.config._attn_implementation == "sdpa" and
# if (
# self.config._attn_implementation == "sdpa"
# and not using_static_cache
# and not output_attentions
# ):
# if AttentionMaskConverter._ignore_causal_mask_sdpa(
# attention_mask,
# inputs_embeds=input_tensor,
# past_key_values_length=past_seen_tokens,
# is_training=self.training,
# ):
# return None
# dtype, device = input_tensor.dtype, input_tensor.device
# min_dtype = torch.finfo(dtype).min
# sequence_length = input_tensor.shape[1]
# if using_static_cache:
# target_length = past_key_values.get_max_length()
# else:
# target_length = (
# attention_mask.shape[-1]
# if isinstance(attention_mask, torch.Tensor)
# else past_seen_tokens + sequence_length + 1
# )
# # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
# attention_mask,
# sequence_length=sequence_length,
# target_length=target_length,
# dtype=dtype,
# device=device,
# min_dtype=min_dtype,
# cache_position=cache_position,
# batch_size=input_tensor.shape[0],
# )
# if (
# self.config._attn_implementation == "sdpa"
# and attention_mask is not None
# and attention_mask.device.type == "cuda"
# and not output_attentions
# ):
# # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# # Details: https://github.com/pytorch/pytorch/issues/110213
# causal_mask = AttentionMaskConverter._unmask_unattended(
# causal_mask, min_dtype
# )
# return causal_mask


class MllamaForCausalLM(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.model = MllamaTextModel(
            prefix=f"{prefix}.model", config=config, weights=weights
        )
        self.lm_head = SpeculativeHead.load(
            prefix=f"{prefix}.lm_head",
            config=config,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.LongTensor] = None,
        cross_attention_mask: Optional[torch.LongTensor] = None,
        full_text_row_masked_out_mask: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ):
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            cross_attention_states=cross_attention_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = outputs
        # if lm_head_indices is not None:
        #     hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        **kwargs,
    ):
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif (
                input_ids.shape[1] != cache_position.shape[0]
            ):  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {
                "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
                "inputs_embeds": None,
            }

        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            dtype = self.lm_head.weight.dtype
            min_dtype = torch.finfo(dtype).min

            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_length(),
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
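
# MllamaForConditionalGeneration ties everything together: the vision model
# encodes the image tiles, the multi-modal projector maps the vision features
# into the text hidden size, and the language model consumes them through its
# cross-attention layers.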


class MllamaForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator

        self.vision_model = MllamaVisionModel(
            prefix="vision_model", config=config.vision_config, weights=weights
        )
        self.language_model = MllamaForCausalLM(
            prefix="language_model", config=config.text_config, weights=weights
        )
        self.multi_modal_projector = FastLinear.load(
            prefix="multi_modal_projector", config=config, weights=weights, bias=True
        )
        self.config = config