# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Llava-NeXT model."""

from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.models.llava_next.modeling_llava_next import (
    unpad_image,
)
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
from transformers.image_processing_utils import select_best_resolution

def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (height, width).
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size
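
# A minimal worked example of the grid arithmetic (illustrative values, not from the source):
#   get_anyres_image_grid_shape((672, 672), [[336, 672], [672, 336], [672, 672]], 336)
#   -> select_best_resolution picks (672, 672), so the result is (672 // 336, 672 // 336) == (2, 2),
#      i.e. a 2 x 2 grid of 336-pixel patches.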


class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
    def _merge_input_ids_with_image_features(
        self,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
        input_ids: torch.Tensor,
    ):
        """Merges vision_embeddings into inputs_embeds in place."""
        mask = input_ids == self.config.image_token_index
        # Let's pray we have enabled enough slots!
        try:
            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        except Exception as e:
            raise RuntimeError(
                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
            )
        return inputs_embeds
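
    # NOTE: the masked assignment in `_merge_input_ids_with_image_features` requires the number of
    # `image_token_index` placeholders in `input_ids` to equal the total number of image-feature
    # rows after `view(-1, hidden_size)`; any mismatch makes the indexed copy fail, which is what
    # the RuntimeError message above refers to.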

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_idx: Optional[torch.Tensor] = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
    ):
        if token_idx is not None:
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)

            outputs = self.language_model(
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                token_idx=token_idx,
                use_flash_attention=use_flash_attention,
                flash_attention_recompute=flash_attention_recompute,
            )

            logits = outputs[0]

            if not return_dict:
                output = (logits,) + outputs[1:]
                return output

            return outputs
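
    # NOTE: `token_idx` marks the position currently being decoded inside a statically padded
    # sequence (it is supplied by optimum-habana's generation utilities on HPU). This override
    # only runs the language model when `token_idx` is provided; the surrounding TGI Gaudi code
    # is assumed to always pass it for this model.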

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        **kwargs,
    ):
        """
        Inherits from LlavaNextForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
        The only differences are:
        - add new arg token_idx
        - add the process of merging images into inputs_embeds
        """
        token_idx = kwargs.get("token_idx", None)
        if token_idx is None:
            return super().prepare_inputs_for_generation(
                input_ids=input_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                attention_mask=attention_mask,
                **kwargs,
            )
        else:
            use_flash_attention = kwargs.get("use_flash_attention", False)
            flash_attention_recompute = kwargs.get("flash_attention_recompute", False)

            position_ids = kwargs.get("position_ids", None)
            labels = kwargs.get("labels", None)
            if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
                vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
                vision_feature_layer = kwargs.get("vision_feature_layer", None)
                vision_feature_select_strategy = (
                    vision_feature_select_strategy
                    if vision_feature_select_strategy is not None
                    else self.config.vision_feature_select_strategy
                )
                vision_feature_layer = (
                    vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
                )

                # 1. Extract the input embeddings
                inputs_embeds = self.get_input_embeddings()(input_ids)
                # 2. Merge text and images
                batch_size, num_patches, num_channels, height, width = pixel_values.shape
                reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
                image_features = self.vision_tower(
                    reshaped_pixel_values,
                    output_hidden_states=True,
                    use_flash_attention=use_flash_attention,
                    flash_attention_recompute=flash_attention_recompute,
                )

                selected_image_feature = image_features.hidden_states[vision_feature_layer]

                if vision_feature_select_strategy == "default":
                    selected_image_feature = selected_image_feature[:, 1:]
                elif vision_feature_select_strategy == "full":
                    selected_image_feature = selected_image_feature
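                # NOTE: the "default" strategy drops the first vision token, which for the
                # CLIP-style vision tower used by Llava-NeXT is assumed to be the CLS token,
                # keeping only per-patch embeddings; "full" keeps every token unchanged.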

                image_features = self.multi_modal_projector(selected_image_feature)

                # split up image_features for each of the individual images
                # hence we get a list of image_features, each of shape (num_patches, num_image_tokens, hidden_size),
                # e.g. (5, num_image_tokens, hidden_size) if each image has 5 image features (base image + 4 patches)
                split_sizes = [image.shape[0] for image in pixel_values]
                image_features = torch.split(image_features, split_sizes, dim=0)

                # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]

                        if height * width != base_image_feature.shape[0]:
                            raise ValueError("The number of patches is not consistent with the image size.")

                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                            image_sizes[image_idx].tolist(),
                            self.config.image_grid_pinpoints,
                            self.config.vision_config.image_size,
                        )

                        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
                            ),
                            dim=-1,
                        )
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
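                        # The chain above rearranges the per-patch features into a channel-first
                        # (hidden_size, grid_height * height, grid_width * width) feature map, crops
                        # the padding that preprocessing added to reach the selected resolution
                        # (unpad_image), appends an image_newline embedding at the end of each row,
                        # and flattens the map back into a (num_tokens, hidden_size) sequence that is
                        # concatenated after the base image features.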
                    else:
                        image_feature = image_feature[0]
                        image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
                    new_image_features.append(image_feature)
                image_features = torch.stack(new_image_features, dim=0)
                inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
                self.image_offset = image_features.shape[1] - 1  # image_token has occupied 1 token position.
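                # NOTE: torch.stack requires every image in the batch to yield the same number of
                # feature tokens, so all images are expected to fall into the same resolution bucket
                # here. `image_offset` records how many positions an image occupies beyond its single
                # prompt image token; it is assumed to be consumed elsewhere in the Gaudi TGI path.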
            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
            # generation with cache
            elif past_key_values is not None:
                seq_len = input_ids.shape[1]
                pad_len = seq_len - token_idx.item()
                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                # Get the target length
                past_length = first_layer_past_key_value.shape[-1]
                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                attention_mask = extended_attention_mask
                attention_mask[:, -pad_len:] = 0
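                # NOTE: with static shapes on HPU the sequence is padded to a fixed length, so
                # `pad_len` counts the trailing padding positions beyond `token_idx`; zeroing them
                # keeps the cached-generation step from attending to padding.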

            if attention_mask is not None and position_ids is None:
                # create position_ids on the fly for batch generation
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                if past_key_values:
                    if token_idx is not None:
                        position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                    else:
                        position_ids = position_ids[:, -input_ids.shape[1] :]

            # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
            if inputs_embeds is not None and past_key_values is None:
                model_inputs = {"inputs_embeds": inputs_embeds}
            else:
                model_inputs = {"input_ids": input_ids}

            model_inputs.update(
                {
                    "position_ids": position_ids,
                    "past_key_values": past_key_values,
                    "use_cache": kwargs.get("use_cache"),
                    "attention_mask": attention_mask,
                    "token_idx": token_idx,
                    "labels": labels,
                    "use_flash_attention": use_flash_attention,
                    "flash_attention_recompute": flash_attention_recompute,
                }
            )

            return model_inputs