# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Mllama model. """
2025-02-27 15:12:15 +00:00
from typing import Optional , Tuple , List , Union
2024-10-02 09:22:13 +00:00
import torch
import torch . utils . checkpoint
from torch import nn
2025-02-27 15:12:15 +00:00
from loguru import logger
2024-10-02 09:22:13 +00:00
from transformers . activations import ACT2FN
2025-02-27 15:12:15 +00:00
from optimum . habana . transformers . models import GaudiMllamaForConditionalGeneration
from optimum . habana . transformers . models . mllama . modeling_mllama import _prepare_cross_attention_mask
from transformers . modeling_outputs import CausalLMOutputWithPast
2024-10-02 09:22:13 +00:00


class MllamaForConditionalGeneration(GaudiMllamaForConditionalGeneration):
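    """
    Gaudi/HPU variant of Mllama.

    Thin wrapper around `GaudiMllamaForConditionalGeneration` from optimum-habana that overrides
    `forward` and `prepare_inputs_for_generation` to thread through `token_idx` (the static-shape
    decode index used on HPU), `bucket_internal`, and the Habana flash-attention flags.
    """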
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        aspect_ratio_mask: Optional[torch.Tensor] = None,
        aspect_ratio_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        token_idx: Optional[torch.Tensor] = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Copied from MllamaForConditionalGeneration : : forward : https : / / github . com / huggingface / transformers / blob / v4 .45 .2 / src / transformers / models / mllama / modeling_mllama . py #L2077
The only differences are :
- add token_idx input
- add use_flash_attention and flash_attention_recompute
"""
        full_text_row_masked_out_mask = kwargs.get("full_text_row_masked_out_mask", None)
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
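
        # Vision inputs are not processed here: during generation, `prepare_inputs_for_generation`
        # has already turned `pixel_values` into `cross_attention_states`, so only the language
        # model is run.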
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            token_idx=token_idx,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=flash_attention_recompute,
        )
        logits = outputs[0]
        if not return_dict:
            output = (logits,) + outputs[1:]
            return output

        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        pixel_values=None,
        aspect_ratio_ids=None,
        aspect_ratio_mask=None,
        cross_attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        """
        Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208
        The only differences are:
        - add token_idx handling
        - add bucket_internal handling
        - add use_flash_attention and flash_attention_recompute
        """
        token_idx = kwargs.get("token_idx", None)
        if token_idx is None:
            return super().prepare_inputs_for_generation(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                pixel_values=pixel_values,
                aspect_ratio_ids=aspect_ratio_ids,
                aspect_ratio_mask=aspect_ratio_mask,
                cross_attention_mask=cross_attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )
        else:
            use_flash_attention = kwargs.get("use_flash_attention", False)
            flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
            position_ids = kwargs.get("position_ids", None)
            output_attentions = kwargs.get("output_attentions", None)
            output_hidden_states = kwargs.get("output_hidden_states", None)
            return_dict = kwargs.get("return_dict", None)
            labels = kwargs.get("labels", None)
            cross_attention_states = kwargs.get("cross_attention_states", None)
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
            bucket_internal = kwargs.get("bucket_internal", None)
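
            # Decode steps with a cache keep static shapes by selecting only the token at
            # `token_idx - 1`; with `bucket_internal`, the prefill inputs are instead sliced up to
            # `token_idx`.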
            if past_key_values is not None:
                if token_idx is not None:
                    input_ids = torch.index_select(input_ids, 1, token_idx - 1)
                elif inputs_embeds is not None:  # Exception 1
                    input_ids = input_ids[:, -cache_position.shape[0] :]
                elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                    input_ids = input_ids[:, cache_position]
            elif bucket_internal and token_idx is not None:
                # for the 1st token we can slice the inputs till token idx for the fwd pass.
                input_ids = input_ids[:, :token_idx]
                attention_mask = attention_mask[:, :token_idx]
                if cross_attention_mask is not None:
                    cross_attention_mask = cross_attention_mask[:, :token_idx, ...]

            # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
            if attention_mask is not None and position_ids is None:
                # create position_ids on the fly for batch generation
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                if past_key_values:
                    if token_idx is not None:
                        position_ids = torch.index_select(position_ids, 1, token_idx - 1)
                    else:
                        position_ids = position_ids[:, -input_ids.shape[1] :]

                        # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                        position_ids = position_ids.clone(memory_format=torch.contiguous_format)
            if pixel_values is not None and inputs_embeds is not None:
                raise ValueError(
                    "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
                )

            if pixel_values is not None and cross_attention_states is not None:
                raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")

            if pixel_values is not None:
                if aspect_ratio_ids is None:
                    raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
                # get vision tokens from vision model
                vision_outputs = self.vision_model(
                    pixel_values=pixel_values,
                    aspect_ratio_ids=aspect_ratio_ids,
                    aspect_ratio_mask=aspect_ratio_mask,
                    output_hidden_states=output_hidden_states,
                    output_attentions=output_attentions,
                    return_dict=return_dict,
                    use_flash_attention=use_flash_attention,
                )
                cross_attention_states = vision_outputs[0]
                cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
                    -1, cross_attention_states.shape[-2], self.hidden_size
                )
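
            # Expand the per-image cross-attention mask to one entry per vision token and derive
            # `full_text_row_masked_out_mask`, which zeroes out text rows that attend to no image.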
            if cross_attention_mask is not None:
                cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
                    cross_attention_mask,
                    num_vision_tokens=self.vision_model.num_patches,
                    dtype=self.dtype,
                    token_idx=token_idx,
                )
            else:
                full_text_row_masked_out_mask = None
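
            # Keep only the mask rows for the positions actually fed to the decoder: either the
            # `cache_position` slice or, on Gaudi decode steps, the single row at `token_idx - 1`.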
            if cross_attention_mask is not None:
                if cache_position is not None:
                    cross_attention_mask = cross_attention_mask[:, :, cache_position]
                    full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
                elif past_key_values is not None:
                    if token_idx is not None:
                        cross_attention_mask = torch.index_select(cross_attention_mask, -2, token_idx - 1)
                        full_text_row_masked_out_mask = torch.index_select(
                            full_text_row_masked_out_mask, -2, token_idx - 1
                        )
                    else:
                        cross_attention_mask = cross_attention_mask[:, :, -1:]
                        full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, -1:]

            # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
            if inputs_embeds is not None and past_key_values is None:
                model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
            else:
                # The clone here is for the same reason as for `position_ids`.
                model_inputs = {
                    "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
                    "inputs_embeds": None,
                }

            if num_logits_to_keep is not None:
                model_inputs["num_logits_to_keep"] = num_logits_to_keep

            # keep cache_position implementation as None for HPU
            cache_position = None
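
            # Everything the overridden `forward` expects (masks, `token_idx`, the Habana
            # flash-attention flags and the precomputed cross-attention tensors) travels through
            # `model_inputs`.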
            model_inputs.update(
                {
                    "position_ids": position_ids,
                    "past_key_values": past_key_values,
                    "use_cache": kwargs.get("use_cache"),
                    "attention_mask": attention_mask,
                    "token_idx": token_idx,
                    "labels": labels,
                    "return_dict": kwargs.get("return_dict"),
                    "full_text_row_masked_out_mask": full_text_row_masked_out_mask,
                    "use_flash_attention": use_flash_attention,
                    "cross_attention_mask": cross_attention_mask,
                    "cross_attention_states": cross_attention_states,
                    "output_attentions": output_attentions,
                    "flash_attention_recompute": flash_attention_recompute,
                }
            )

            return model_inputs
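

# Illustrative sketch only (not used by the model): how `token_idx` selects the current decode
# input from a statically padded buffer, assuming a hypothetical padded length of 8 and 5 tokens
# generated so far:
#
#     input_ids = torch.zeros(1, 8, dtype=torch.long)  # padded buffer (hypothetical values)
#     token_idx = torch.tensor([5])  # position about to be written
#     step_input = torch.index_select(input_ids, 1, token_idx - 1)  # shape (1, 1): last real token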