# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""

import math
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import (
    DynamicCache,
    StaticCache,
)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

from text_generation_server.layers import (
    FastLinear,
    SpeculativeHead,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

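# Illustrative note (comments only, nothing here is executed): with head_dim = 4 and a
# single vector x = [x0, x1, x2, x3], rotate_half(x) == [-x2, -x3, x0, x1], so
# apply_rotary_pos_emb mixes the two halves of the head dimension with the cos/sin terms.
# With the default unsqueeze_dim=1, cos/sin of shape [batch, seq_len, head_dim] broadcast
# against q/k of shape [batch, heads, seq_len, head_dim], which is how the text
# self-attention below calls it.
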
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

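# Shape sketch for repeat_kv (illustrative comment, not executed), assuming 8 KV heads
# shared by 32 query heads, i.e. n_rep = 4:
#   input : (batch, num_key_value_heads=8,  seqlen, head_dim)
#   output: (batch, num_attention_heads=32, seqlen, head_dim)
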
def _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: torch.Tensor,
    num_patches: int,
    target_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Expand aspect ratio mask to target_length
    batch_size, max_num_tiles = aspect_ratio_mask.shape
    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
    attention_mask = attention_mask.repeat(1, 1, target_length, 1)

    # Mask padding patches
    pad_patches = target_length - num_patches
    attention_mask[:, :, -pad_patches:] = 0

    # Invert the mask (0 -> 1, 1 -> 0)
    attention_mask = 1 - attention_mask

    # Reshape to 2D and create 4D attention mask
    # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
    attention_mask = attention_mask.reshape(
        batch_size, max_num_tiles * target_length, 1
    )
    attention_mask = (
        attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
    )
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask

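# Note (illustrative): `aspect_ratio_mask` marks which tiles are real (1) vs unused (0).
# After padding-patch masking and inversion, the outer product above yields an additive
# bias of shape (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
# in which entries whose query *and* key positions are padding get torch.finfo(dtype).min and
# everything else gets 0; it is consumed by F.scaled_dot_product_attention in the vision encoder.
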
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`torch.Tensor`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(
            target_length, device=device
        ) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = (
                causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            )
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)

    return causal_mask

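# Worked micro-example (comments only): with sequence_length=3, target_length=5 and
# cache_position=[2, 3, 4] (two tokens already in the cache), the mask built above lets
#   query 0 attend keys 0..2, query 1 attend keys 0..3, query 2 attend keys 0..4,
# with `min_dtype` at every other position, before being expanded to
# (batch_size, 1, 3, 5) and combined with the 2D padding mask when one is given.
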
def _prepare_cross_attention_mask(
    cross_attention_mask: torch.Tensor,
    num_vision_tokens: int,
    dtype: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # reshape so it can be used by attn module
    batch_size, text_total_length, *_ = cross_attention_mask.shape
    cross_attention_mask = cross_attention_mask.repeat_interleave(
        num_vision_tokens, dim=3
    )
    cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
    cross_attention_mask = cross_attention_mask.unsqueeze(1)

    # invert the mask
    inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
    cross_attention_mask = inverted_cross_attn_mask.masked_fill(
        inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
    )

    # apply full-row bias, which returns a 4D tensor of shape [B, H, S1, 1] whose value is 0 if a full row in the
    # cross attn mask's last dimension contains negative infinity values, otherwise it's 1
    negative_inf_value = torch.finfo(dtype).min
    full_text_row_masked_out_mask = (
        (cross_attention_mask != negative_inf_value)
        .any(dim=-1)
        .type_as(cross_attention_mask)[..., None]
    )
    cross_attention_mask *= full_text_row_masked_out_mask

    return cross_attention_mask, full_text_row_masked_out_mask

# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states

class MllamaVisionSdpaAttention(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()

        self.embed_dim = config.hidden_size
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_heads = config.attention_heads // weights.process_group.size()

        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_state)
        query, key, value = qkv.split(
            [
                self.head_dim * self.num_heads,
                self.head_dim * self.num_heads,
                self.head_dim * self.num_heads,
            ],
            dim=2,
        )

        batch_size, q_seq_len, _ = query.shape
        _, kv_seq_len, _ = key.shape

        query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
        key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
        value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_seq_len, -1)

        output = self.o_proj(attn_output)
        return output

class MllamaVisionEncoderLayer(nn.Module):
    def __init__(self, *, prefix, config, weights, is_gated: bool):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = MllamaVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

        self.input_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
        )
        self.post_attention_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
        )

        # there used to be an if else here, no code path
        if is_gated:
            self.gate_attn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
            )
            self.gate_ffn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
            )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
        hidden_state = residual + gate_attn * hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
        hidden_state = residual + gate_ffn * hidden_state

        return hidden_state

class MllamaVisionEncoder(nn.Module):
    def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
        super().__init__()
        self.config = config
        self.layers = [
            MllamaVisionEncoderLayer(
                prefix=f"{prefix}.layers.{i}",
                config=config,
                weights=weights,
                is_gated=is_gated,
            )
            for i in range(num_layers)
        ]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        encoder_states = [hidden_states]
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
            )
            hidden_states = layer_outputs
            encoder_states.append(hidden_states)

        return hidden_states, encoder_states

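# Note: the encoder returns both the final hidden state and `encoder_states`, i.e. the input
# followed by every layer's output (num_layers + 1 entries); MllamaVisionModel selects the
# entries listed in `config.intermediate_layers_indices` out of this list.
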
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id

        self.embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.embedding", weights=weights
        )
        self.gate = nn.Parameter(
            weights.get_tensor(f"{prefix}.gate"), requires_grad=False
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        embeddings = self.embedding(aspect_ratio_ids)
        embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)

        # Always gated.
        embeddings = embeddings * self.gate.tanh()

        hidden_state = hidden_state + embeddings
        return hidden_state

class MllamaPrecomputedPositionEmbedding(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(
            weights.get_tensor(f"{prefix}.gate"), requires_grad=False
        )

        # position embedding
        embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
        )
        self.gated_position_embedding = (1 - self.gate.tanh()) * embedding

        self.tile_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.tile_embedding", weights=weights
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        # position embeddings
        hidden_state = hidden_state + self.gated_position_embedding.view(
            1, 1, self.num_patches, self.hidden_size
        )

        # precomputed tile position embeddings
        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
        batch_size = hidden_state.shape[0]
        tile_position_embedding = tile_position_embedding.reshape(
            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
        )
        gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
        hidden_state = hidden_state + gated_tile_position_embedding

        return hidden_state

class MllamaVisionModel(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.in_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5
        self.dtype = weights.dtype

        self.patch_embedding = nn.Conv2d(
            in_channels=config.in_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )

        self.class_embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
        )

        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            prefix=f"{prefix}.gated_positional_embedding",
            config=config,
            weights=weights,
        )

        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.pre_tile_positional_embedding",
            config=config,
            weights=weights,
        )
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.post_tile_positional_embedding",
            config=config,
            weights=weights,
        )

        ## layer norms
        self.layernorm_pre = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_pre",
            weights=weights,
            # torch default
            eps=1e-05,
        )
        self.layernorm_post = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_post",
            weights=weights,
            # torch default
            eps=1e-05,
        )

        ## encoders
        self.transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.transformer",
            config=config,
            weights=weights,
            is_gated=False,
            num_layers=config.num_hidden_layers,
        )
        self.global_transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.global_transformer",
            config=config,
            weights=weights,
            is_gated=True,
            num_layers=config.num_global_layers,
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
            pixel_values.shape
        )

        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1
        )

        # patch embedding
        patch_embeds = self.patch_embedding(pixel_values)
        hidden_state = patch_embeds.flatten(2).transpose(1, 2)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, -1, dim
        )
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim
        )
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, num_patches, dim
        )
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0,
            0,
            0,
            num_padding_patches,
        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(
                batch_size * num_concurrent_media, -1
            )
            attention_mask = _prepare_aspect_ratio_attention_mask(
                aspect_ratio_mask=attention_mask,
                num_patches=self.num_patches,
                target_length=hidden_state.shape[2],
                dtype=self.dtype,
            )

        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
        hidden_state, all_intermediate_hidden_states = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        intermediate_hidden_states = [
            hidden_state
            for idx, hidden_state in enumerate(all_intermediate_hidden_states)
            if idx in self.intermediate_layers_indices
        ]
        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches),
            dim,
        )
        hidden_state, _ = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            -1,
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
        return hidden_state

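# The vision tower above returns a tensor of shape
# (batch_size, num_concurrent_media, num_tiles, num_patches,
#  hidden_size * (1 + len(intermediate_layers_indices))),
# i.e. the final hidden state concatenated with the selected intermediate layer outputs.
# The text side below mirrors the upstream Mllama text model, with TGI tensor-parallel layers;
# its cross-attention layers consume these vision states after the multi-modal projector.
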
class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, *, prefix, config, weights, layer_idx):
        super().__init__()
        self.config = config
        self.num_heads = self.config.num_attention_heads
        self.num_key_value_heads = self.config.num_key_value_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.layer_idx = layer_idx

        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            self.num_key_value_heads // weights.process_group.size()
        )

        self.q_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.k_proj",
            weights=weights,
            bias=False,
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.v_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        self.q_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.k_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[torch.Tensor] = None,
        past_key_value=None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        query_states = self.q_norm(query_states)

        if cross_attention_states is not None:
            key_states = self.k_proj(cross_attention_states)
            value_states = self.v_proj(cross_attention_states)
            key_states = key_states.view(
                bsz, -1, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            value_states = value_states.view(
                bsz, -1, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            key_states = self.k_norm(key_states)

            if past_key_value is not None:
                # if we have a new image + new tokens, we only computed key_states on that new image
                # we still update the cross key states, past_image, new_image. And use it!
                key_states, value_states = past_key_value.update(
                    key_states,
                    value_states,
                    self.layer_idx,
                    {"cache_position": cache_position},
                )
        elif cache_position[0] != 0:
            key_states, value_states = (
                past_key_value.key_cache[self.layer_idx],
                past_key_value.value_cache[self.layer_idx],
            )
        else:
            raise ValueError(
                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
            )

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights, past_key_value

# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        shape = x.shape
        gate_up_states = self.gate_up_proj(x)
        gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
        result = self.down_proj(
            self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1]
        )
        return result

class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention and feedforward."""

    def __init__(self, *, prefix, config, weights, layer_idx) -> None:
        super().__init__()
        self.cross_attn = MllamaTextCrossAttention(
            prefix=f"{prefix}.cross_attn",
            config=config,
            weights=weights,
            layer_idx=layer_idx,
        )

        self.input_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.cross_attn_attn_gate = torch.nn.Parameter(
            weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
        )

        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.post_attention_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.cross_attn_mlp_gate = torch.nn.Parameter(
            weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        attention_mask: torch.Tensor,
        full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
        past_key_value=None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, attn_weights, past_key_value = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            cross_attention_states=cross_attention_states,
            past_key_value=past_key_value,
            cache_position=cache_position,
        )
        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if full_text_row_masked_out_mask is not None:
            hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states  # type: ignore
        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states

        return hidden_states

class MllamaTextSelfAttention(nn.Module):
    def __init__(self, *, prefix, config, weights, layer_idx):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            self.num_key_value_heads // weights.process_group.size()
        )
        self.layer_idx = layer_idx

        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
        past_key_value=None,
        cache_position=None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        qkv = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [
                self.head_dim * self.num_heads,
                self.head_dim * self.num_key_value_heads,
                self.head_dim * self.num_key_value_heads,
            ],
            dim=2,
        )

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            # TODO
            # attn_mask=causal_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value

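# Note on MllamaTextSelfAttention above: as the TODO indicates, `causal_mask` is currently not
# passed to scaled_dot_product_attention; masking is only applied through `is_causal`, which is
# True only when no mask was built and q_len > 1.
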
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class MllamaTextRMSNorm(nn.Module):
    def __init__(self, weight, eps):
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    @classmethod
    def load(cls, *, prefix, weights, eps):
        weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.weight"), requires_grad=False
        )
        return cls(weight=weight, eps=eps)

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT
class MllamaSelfAttentionDecoderLayer(nn.Module):
    def __init__(self, *, prefix, config, weights, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MllamaTextSelfAttention(
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
            layer_idx=layer_idx,
        )

        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.input_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value=None,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

class MllamaRotaryEmbedding(nn.Module):
    def __init__(
        self,
        *,
        config,
        weights,
    ):
        super().__init__()
        device = weights.device
        # Extra keyword arguments for the rope init function; referenced by
        # `_dynamic_frequency_update` below, so it must exist even when empty.
        self.rope_kwargs = {}
        self.rope_type = config.rope_scaling["rope_type"]
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        inv_freq = inv_freq.to(device=device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

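# Shape sketch for MllamaRotaryEmbedding.forward (illustrative): given x of shape
# (batch, seq_len, hidden) and position_ids of shape (batch, seq_len), it returns cos and sin
# of shape (batch, seq_len, head_dim) in x's dtype, which MllamaTextSelfAttention feeds to
# apply_rotary_pos_emb with the default unsqueeze_dim=1.
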
class MllamaTextModel(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        self.cross_attention_layers = config.cross_attention_layers

        self.layers = []
        for layer_idx in range(config.num_hidden_layers):
            if layer_idx in self.cross_attention_layers:
                self.layers.append(
                    MllamaCrossAttentionDecoderLayer(
                        prefix=f"{prefix}.layers.{layer_idx}",
                        config=config,
                        weights=weights,
                        layer_idx=layer_idx,
                    )
                )
            else:
                self.layers.append(
                    MllamaSelfAttentionDecoderLayer(
                        prefix=f"{prefix}.layers.{layer_idx}",
                        config=config,
                        weights=weights,
                        layer_idx=layer_idx,
                    )
                )

        # TODO Should we use this slow norm ?
        # self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        # TODO Anything specific ?
        head_size = config.hidden_size // config.num_attention_heads
        self.rotary_emb = MllamaRotaryEmbedding(config=config, weights=weights)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.FloatTensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        full_text_row_masked_out_mask: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
        )

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            if (
                idx in self.cross_attention_layers
                and cross_attention_states is None
                and (
                    past_key_values is None
                    or (
                        past_key_values is not None
                        and past_key_values.get_seq_length(idx) == 0
                    )
                )
            ):
                continue

            layer_outputs = decoder_layer(
                hidden_states,
                cross_attention_states=cross_attention_states,
                cross_attention_mask=cross_attention_mask,
                attention_mask=causal_mask,
                full_text_row_masked_out_mask=full_text_row_masked_out_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line
        # self.config._attn_implementation == "sdpa" and
        # if self.config._attn_implementation == "sdpa" and not using_static_cache:
        if self.config._attn_implementation == "sdpa" and not using_static_cache:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask

class MllamaForCausalLM(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.model = MllamaTextModel(
            prefix=f"{prefix}.model", config=config, weights=weights
        )
        self.lm_head = SpeculativeHead.load(
            prefix=f"{prefix}.lm_head",
            config=config,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.LongTensor] = None,
        cross_attention_mask: Optional[torch.LongTensor] = None,
        full_text_row_masked_out_mask: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ):
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        # TODO
        outputs = self.model(
            input_ids=input_ids,
            cross_attention_states=cross_attention_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
        )

        hidden_states = outputs.last_hidden_state
        # if lm_head_indices is not None:
        #     hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)

        return (
            CausalLMOutputWithPast(
                logits=logits,
                past_key_values=outputs.past_key_values,
            ),
            speculative_logits,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif (
                input_ids.shape[1] != cache_position.shape[0]
            ):  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {
                "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
                "inputs_embeds": None,
            }

        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            dtype = self.lm_head.weight.dtype
            min_dtype = torch.finfo(dtype).min

            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_length(),
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

class MllamaForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        # TODO check how this is determined
        config.text_config._attn_implementation = "sdpa"
        # self.hidden_size = (
        #     config.text_config.hidden_size // weights.process_group.size()
        # )
        self.hidden_size = config.text_config.hidden_size

        self.vision_model = MllamaVisionModel(
            prefix="vision_model", config=config.vision_config, weights=weights
        )
        self.language_model = MllamaForCausalLM(
            prefix="language_model", config=config.text_config, weights=weights
        )
        self.multi_modal_projector = FastLinear.load(
            prefix="multi_modal_projector", config=config, weights=weights, bias=True
        )
        self.config = config
        self.dtype = weights.dtype
        self.device = weights.device

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        aspect_ratio_mask: Optional[List[List[int]]] = None,
        aspect_ratio_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[List[List[List[int]]]] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        image_hidden_states=None,
        image_attention_mask=None,
    ):
        if past_key_values is None:
            past_key_values = DynamicCache(
                num_hidden_layers=self.config.text_config.num_hidden_layers
            )
        elif isinstance(past_key_values, list):
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and cross_attention_states is not None:
            raise ValueError(
                "`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
            )

        if pixel_values is not None:
            if aspect_ratio_ids is None:
                raise ValueError(
                    "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
                )
            # get vision tokens from vision model
            vision_states = self.vision_model(
                pixel_values, aspect_ratio_ids, aspect_ratio_mask
            )
            cross_attention_states = self.multi_modal_projector(vision_states).reshape(
                -1, vision_states.shape[-2], self.hidden_size
            )

        if cross_attention_mask is not None:
            cross_attention_mask, full_text_row_masked_out_mask = (
                _prepare_cross_attention_mask(
                    cross_attention_mask,
                    num_vision_tokens=self.vision_model.num_patches,
                    dtype=self.dtype,
                )
            )
        else:
            full_text_row_masked_out_mask = None

        if cross_attention_mask is not None and cache_position is not None:
            cross_attention_mask = cross_attention_mask[:, :, cache_position]
            full_text_row_masked_out_mask = full_text_row_masked_out_mask[
                :, :, cache_position
            ]

        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )
        return outputs
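
# Usage sketch (comments only, assuming the surrounding TGI weight-loading machinery drives this module):
#  * prefill with an image: call forward() with `input_ids`, `pixel_values`, `aspect_ratio_ids`,
#    `aspect_ratio_mask` and `cross_attention_mask`; the vision tower and multi_modal_projector
#    produce `cross_attention_states`, and the DynamicCache stores both self- and cross-attention KV.
#  * subsequent decode steps: call forward() with the new `input_ids`, the returned
#    `past_key_values` and `cache_position`, and no `pixel_values`; the cross-attention layers
#    then reuse the cached vision key/values.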