Fix the style check errors

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-03-21 05:47:18 +00:00
parent bb55318f81
commit 3c6630c6e9
2 changed files with 200 additions and 142 deletions


@@ -27,6 +27,7 @@ from transformers.models.llava_next.modeling_llava_next import (
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
from transformers.image_processing_utils import select_best_resolution


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@@ -49,6 +50,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size


# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
    """
@@ -72,7 +74,9 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
    if not isinstance(image_size, (list, tuple)):
        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
            raise TypeError(
                f"image_size invalid type {type(image_size)} with value {image_size}"
            )
        image_size = image_size.tolist()

    best_resolution = select_best_resolution(image_size, grid_pinpoints)
@@ -86,6 +90,7 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
    num_patches += 1
    return num_patches


class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
    def forward(
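The two helpers above are plain functions, so their relationship can be checked in isolation. The sketch below is not part of the commit; it assumes both helpers are in scope and uses illustrative grid_pinpoints and patch_size values in the style of a typical llava-next config.

grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
patch_size = 336
image_size = (899, 1024)  # (height, width) of the original image

num_patches = image_size_to_num_patches(image_size, grid_pinpoints, patch_size)
grid_height, grid_width = get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size)

# With pinpoints that are exact multiples of patch_size, the patch count is the
# grid area plus one extra patch for the downsampled base image.
assert num_patches == grid_height * grid_width + 1  # here: 5 == 2 * 2 + 1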
@@ -110,11 +115,19 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
    ):
        if token_idx is not None:
            output_attentions = (
                output_attentions
                if output_attentions is not None
                else self.config.output_attentions
            )
            output_hidden_states = (
                output_hidden_states
                if output_hidden_states is not None
                else self.config.output_hidden_states
            )
            return_dict = (
                return_dict if return_dict is not None else self.config.use_return_dict
            )

            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
@@ -140,8 +153,14 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
            return outputs

    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
    def pack_image_features(
        self,
        image_features,
        image_sizes,
        vision_feature_select_strategy,
        image_newline=None,
    ):
        """
        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -165,7 +184,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = (
                    self.config.vision_config.image_size
                    // self.config.vision_config.patch_size
                )

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
@@ -174,7 +196,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                )

                if (
                    np.prod(image_feature.shape)
                    % (num_patch_height * num_patch_width * height * width)
                    != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning_once(
@@ -183,7 +207,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                        " visual encoder that does not have CLS."
                    )
                image_feature = image_feature.view(
                    num_patch_height, num_patch_width, height, width, -1
                )
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
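The view/permute/flatten sequence above stitches the per-patch token grids back into one spatial feature map. The following is a self-contained toy sketch, not part of the commit, with made-up sizes.

import torch

num_patch_height, num_patch_width, height, width, embed_dim = 2, 2, 3, 3, 5
image_feature = torch.randn(num_patch_height * num_patch_width, height * width, embed_dim)

image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)

# The patch grids are rejoined into (embed_dim, grid_h * height, grid_w * width).
print(image_feature.shape)  # torch.Size([5, 6, 6])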
@@ -202,11 +228,15 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat(
                        (image_feature, image_newline[None].to(image_feature)), dim=0
                    )
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        image_features = torch.cat(new_image_features, dim=0)
        feature_lens = torch.tensor(
            feature_lens, dtype=torch.long, device=image_features.device
        )
        return image_features, feature_lens

    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
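pack_image_features returns a flat tensor of visual vectors plus their lengths; in prepare_inputs_for_generation further down, those vectors are written into the text embeddings with masked_scatter. The toy sketch below is not part of the commit; the image token id and tensor sizes are made up for illustration.

import torch

image_token_index = 32000  # hypothetical id of the <image> placeholder token
hidden_size = 4

input_ids = torch.tensor([[1, 32000, 32000, 2]])
inputs_embeds = torch.zeros(1, 4, hidden_size)
image_features = torch.arange(2 * hidden_size, dtype=torch.float).reshape(2, hidden_size)

special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)

# Masked positions are overwritten, in order, with rows of image_features.
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1], inputs_embeds[0, 2])  # the two image feature rows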
@@ -247,11 +277,16 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
        ]
        if pixel_values.dim() == 5:
            # stacked if input is (batch_size, num_patches, num_channels, height, width)
            _pixel_values_list = [
                pix_val[:num_patch]
                for pix_val, num_patch in zip(pixel_values, image_num_patches)
            ]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
            raise ValueError(
                f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
            )

        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        # If we have one vision feature layer, return the corresponding hidden states,
@@ -259,7 +294,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_features.hidden_states[vision_feature_layer]
        else:
            hs_pool = [
                image_features.hidden_states[layer_idx]
                for layer_idx in vision_feature_layer
            ]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        if vision_feature_select_strategy == "default":
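When vision_feature_layer is a list, hidden states from several encoder layers are concatenated along the feature dimension. A stand-alone sketch, not part of the commit, with random stand-ins for image_features.hidden_states:

import torch

# Four fake "layers" of shape (num_patches, seq_len, dim)
hidden_states = [torch.randn(3, 10, 8) for _ in range(4)]
vision_feature_layer = [-2, -1]

hs_pool = [hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
print(selected_image_feature.shape)  # torch.Size([3, 10, 16])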
@@ -272,138 +310,158 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
        return image_features

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        **kwargs,
    ):
        """
        Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
        The only differences are:
        - add new args token_idx
        - add the process of merging images into inputs_embeds
        """
        token_idx = kwargs.get("token_idx", None)
        if token_idx is None:
            return super().prepare_inputs_for_generation(
                input_ids=input_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                attention_mask=attention_mask,
                **kwargs,
            )
        else:
            use_flash_attention = kwargs.get("use_flash_attention", False)
            flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
            position_ids = kwargs.get("position_ids", None)
            labels = kwargs.get("labels", None)
            if (
                past_key_values is None
                and pixel_values is not None
                and input_ids.shape[1] != 1
            ):
                vision_feature_select_strategy = kwargs.get(
                    "vision_feature_select_strategy", None
                )
                vision_feature_layer = kwargs.get("vision_feature_layer", None)
                vision_feature_select_strategy = (
                    vision_feature_select_strategy
                    if vision_feature_select_strategy is not None
                    else self.config.vision_feature_select_strategy
                )
                vision_feature_layer = (
                    vision_feature_layer
                    if vision_feature_layer is not None
                    else self.config.vision_feature_layer
                )

                # 1. Extract the input embeddings
                inputs_embeds = self.get_input_embeddings()(input_ids)
                # 2. Merge text and images
                image_features = self.get_image_features(
                    pixel_values,
                    image_sizes,
                    vision_feature_layer=vision_feature_layer,
                    vision_feature_select_strategy=vision_feature_select_strategy,
                )

                # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
                image_features, feature_lens = self.pack_image_features(
                    image_features,
                    image_sizes,
                    vision_feature_select_strategy=vision_feature_select_strategy,
                    image_newline=self.image_newline,
                )

                special_image_mask = (
                    input_ids == self.config.image_token_index
                ).unsqueeze(-1)
                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
                    inputs_embeds.device
                )
                if inputs_embeds[special_image_mask].numel() != image_features.numel():
                    n_image_tokens = (input_ids == self.config.image_token_index).sum()
                    n_image_features = image_features.shape[0]
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )

                image_features = image_features.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(
                    special_image_mask, image_features
                )

            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
            # generation with cache
            elif past_key_values is not None:
                seq_len = input_ids.shape[1]
                pad_len = seq_len - token_idx.item()
                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(
                    first_layer_past_key_value.float().sum(-2) == 0
                )

                # Get the target length
                past_length = first_layer_past_key_value.shape[-1]
                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                attention_mask = extended_attention_mask
                attention_mask[:, -pad_len:] = 0

            if attention_mask is not None and position_ids is None:
                # create position_ids on the fly for batch generation
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                if past_key_values:
                    if token_idx is not None:
                        position_ids = (
                            torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                        )
                    else:
                        position_ids = position_ids[:, -input_ids.shape[1] :]

            # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
            if inputs_embeds is not None and past_key_values is None:
                model_inputs = {"inputs_embeds": inputs_embeds}
            else:
                model_inputs = {"input_ids": input_ids}

            model_inputs.update(
                {
                    "position_ids": position_ids,
                    "past_key_values": past_key_values,
                    "use_cache": kwargs.get("use_cache"),
                    "attention_mask": attention_mask,
                    "token_idx": token_idx,
                    "labels": labels,
                    "use_flash_attention": use_flash_attention,
                    "flash_attention_recompute": flash_attention_recompute,
                }
            )

            return model_inputs
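The token_idx handling above changes how position_ids are derived from the attention mask during cached generation. The snippet below is not part of the commit; it isolates just that arithmetic so it can be checked with plain torch on CPU.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 0]])

# Default path: cumulative sum minus one, with padded positions filled with 1.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[0, 1, 2, 1, 1], [0, 1, 2, 3, 1]])

# token_idx path (static-shape generation with a cache): each row's single position
# is the number of attended tokens minus one.
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
print(position_ids)  # tensor([[2], [3]])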