# Source: text-generation-inference/server/text_generation_server/models/custom_modeling/mllama.py
# (263 lines, 13 KiB, Python)

# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
from typing import Optional, Tuple, List, Union
import torch
import torch.utils.checkpoint
from torch import nn
from loguru import logger
from transformers.activations import ACT2FN
from optimum.habana.transformers.models import GaudiMllamaForConditionalGeneration
from optimum.habana.transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask
from transformers.modeling_outputs import CausalLMOutputWithPast
class MllamaForConditionalGeneration(GaudiMllamaForConditionalGeneration):
    """Mllama conditional-generation model whose ``forward`` only drives the language model.

    When ``token_idx`` is supplied, ``prepare_inputs_for_generation`` runs the
    vision model and the multi-modal projector and prepares the cross-attention
    masks; ``forward`` then receives the resulting ``cross_attention_states``
    and masks and passes them to ``self.language_model``.
    """

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        aspect_ratio_mask: Optional[torch.Tensor] = None,
        aspect_ratio_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        token_idx: Optional[torch.Tensor] = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the language model on already-prepared inputs.

        Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077
        The differences are:
        - add token_idx input
        - add use_flash_attention and flash_attention_recompute
        - `pixel_values`, `aspect_ratio_mask` and `aspect_ratio_ids` are accepted but
          not used here; the vision pass is done in `prepare_inputs_for_generation`
        - `full_text_row_masked_out_mask` is read from `kwargs` instead of being
          computed in this method

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.

        Returns:
            The language model output unchanged when `return_dict` is truthy,
            otherwise a tuple whose first element is the logits.
        """
        full_text_row_masked_out_mask = kwargs.get("full_text_row_masked_out_mask", None)

        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            token_idx=token_idx,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=flash_attention_recompute,
        )

        logits = outputs[0]
        if not return_dict:
            # NOTE: `return_dict=None` is falsy, so it also takes the tuple path;
            # it is not resolved against the model config in this method.
            return (logits,) + outputs[1:]
        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        pixel_values=None,
        aspect_ratio_ids=None,
        aspect_ratio_mask=None,
        cross_attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        """
        Build the keyword arguments for one generation step.

        Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208
        The differences are:
        - add token_idx handling
        - add bucket_internal handling
        - add use_flash_attention and flash_attention_recompute
        - when `token_idx` is given, the vision model, the multi-modal projector
          and the cross-attention mask preparation run here rather than in `forward`

        When `kwargs` carries no `token_idx` (or it is None) the call is delegated
        to the parent class with all named arguments forwarded.

        Optional entries read from `kwargs`: `token_idx`, `bucket_internal`,
        `use_flash_attention`, `flash_attention_recompute`, `labels`,
        `cross_attention_states`, `output_attentions`, `output_hidden_states`,
        `return_dict`.

        Raises:
            ValueError: if `pixel_values` is combined with `inputs_embeds` or with
                `cross_attention_states`, or is given without `aspect_ratio_ids`.

        Returns:
            A dict of model inputs; its keys match keyword arguments accepted by
            `forward`. `cache_position` is deliberately not part of it.
        """
        token_idx = kwargs.get("token_idx", None)
        if token_idx is None:
            return super().prepare_inputs_for_generation(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                pixel_values=pixel_values,
                aspect_ratio_ids=aspect_ratio_ids,
                aspect_ratio_mask=aspect_ratio_mask,
                cross_attention_mask=cross_attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                # Previously dropped: a named parameter never travels in **kwargs.
                num_logits_to_keep=num_logits_to_keep,
                **kwargs,
            )

        # From here on `token_idx` is known to be set.
        use_flash_attention = kwargs.get("use_flash_attention", False)
        flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
        labels = kwargs.get("labels", None)
        cross_attention_states = kwargs.get("cross_attention_states", None)
        bucket_internal = kwargs.get("bucket_internal", None)

        # Resolve the output flags against the config; the resolved values are
        # used for the vision model call below.
        output_attentions = kwargs.get("output_attentions", None)
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        output_hidden_states = kwargs.get("output_hidden_states", None)
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        return_dict = kwargs.get("return_dict", None)
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if past_key_values is not None:
            # A cache exists: keep only the token at position `token_idx - 1`.
            input_ids = torch.index_select(input_ids, 1, token_idx - 1)
        elif bucket_internal:
            # for the 1st token we can slice the inputs till token idx for the fwd pass.
            input_ids = input_ids[:, :token_idx]
            attention_mask = attention_mask[:, :token_idx]
            if cross_attention_mask is not None:
                cross_attention_mask = cross_attention_mask[:, :token_idx, ...]

        # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
        # A caller-supplied `position_ids` is used as-is; it is only derived from
        # the attention mask when the caller did not provide one.
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = torch.index_select(position_ids, 1, token_idx - 1)
                # This `clone` call is needed to avoid recapturing cuda graphs with
                # `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input
                # `position_ids` would have various stride during the decoding. Here,
                # simply using `.contiguous()` is not sufficient as in the
                # batch size = 1 case, `position_ids` is already contiguous but with
                # varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and cross_attention_states is not None:
            raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")

        if pixel_values is not None:
            if aspect_ratio_ids is None:
                raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
            # get vision tokens from vision model
            vision_outputs = self.vision_model(
                pixel_values=pixel_values,
                aspect_ratio_ids=aspect_ratio_ids,
                aspect_ratio_mask=aspect_ratio_mask,
                output_hidden_states=output_hidden_states,
                output_attentions=output_attentions,
                return_dict=return_dict,
                use_flash_attention=use_flash_attention,
            )
            cross_attention_states = vision_outputs[0]
            cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
                -1, cross_attention_states.shape[-2], self.hidden_size
            )

        if cross_attention_mask is not None:
            cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
                cross_attention_mask,
                num_vision_tokens=self.vision_model.num_patches,
                dtype=self.dtype,
                token_idx=token_idx,
            )
        else:
            full_text_row_masked_out_mask = None

        if cross_attention_mask is not None:
            if cache_position is not None:
                cross_attention_mask = cross_attention_mask[:, :, cache_position]
                full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
            elif past_key_values is not None:
                # Select the mask rows for position `token_idx - 1`, mirroring the
                # `input_ids` selection above.
                cross_attention_mask = torch.index_select(cross_attention_mask, -2, token_idx - 1)
                full_text_row_masked_out_mask = torch.index_select(
                    full_text_row_masked_out_mask, -2, token_idx - 1
                )

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        # `cache_position` is intentionally left out of the returned inputs
        # (kept as None for HPU).
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                # `use_cache` is a named parameter, so it can never be found in `kwargs`.
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "token_idx": token_idx,
                "labels": labels,
                # The caller's raw value is forwarded here, not the config-resolved one.
                "return_dict": kwargs.get("return_dict"),
                "full_text_row_masked_out_mask": full_text_row_masked_out_mask,
                "use_flash_attention": use_flash_attention,
                "cross_attention_mask": cross_attention_mask,
                "cross_attention_states": cross_attention_states,
                "output_attentions": output_attentions,
                "flash_attention_recompute": flash_attention_recompute,
            }
        )
        return model_inputs