Fix style check errors

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-03-21 05:47:18 +00:00
parent bb55318f81
commit 3c6630c6e9
2 changed files with 200 additions and 142 deletions


@@ -27,6 +27,7 @@ from transformers.models.llava_next.modeling_llava_next import (
 from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
 from transformers.image_processing_utils import select_best_resolution
 
+
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
     Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@@ -49,6 +50,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     height, width = select_best_resolution(image_size, grid_pinpoints)
     return height // patch_size, width // patch_size
 
+
 # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
 def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
     """
@@ -72,7 +74,9 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
     # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
     if not isinstance(image_size, (list, tuple)):
         if not isinstance(image_size, (torch.Tensor, np.ndarray)):
-            raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
+            raise TypeError(
+                f"image_size invalid type {type(image_size)} with value {image_size}"
+            )
         image_size = image_size.tolist()
 
     best_resolution = select_best_resolution(image_size, grid_pinpoints)
@@ -86,6 +90,7 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
     num_patches += 1
     return num_patches
 
+
 class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
     def forward(
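(Aside: a rough worked example of the patch count computed by image_size_to_num_patches, with made-up sizes. If select_best_resolution picks 672x672 and patch_size is 336, the grid loop counts a 2x2 grid and the final num_patches += 1 adds the base patch.)

# Illustrative only: 672x672 best resolution, 336 patch size.
height, width, patch_size = 672, 672, 336
num_patches = 0
for i in range(0, height, patch_size):
    for j in range(0, width, patch_size):
        num_patches += 1
num_patches += 1  # the base (full-image) patch
print(num_patches)  # 5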
@@ -110,11 +115,19 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
     ):
         if token_idx is not None:
-            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-            output_hidden_states = (
-                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-            )
-            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+            output_attentions = (
+                output_attentions
+                if output_attentions is not None
+                else self.config.output_attentions
+            )
+            output_hidden_states = (
+                output_hidden_states
+                if output_hidden_states is not None
+                else self.config.output_hidden_states
+            )
+            return_dict = (
+                return_dict if return_dict is not None else self.config.use_return_dict
+            )
 
             if inputs_embeds is None:
                 inputs_embeds = self.get_input_embeddings()(input_ids)
@@ -141,7 +154,13 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             return outputs
 
     # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
-    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+    def pack_image_features(
+        self,
+        image_features,
+        image_sizes,
+        vision_feature_select_strategy,
+        image_newline=None,
+    ):
         """
         Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -165,7 +184,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             if image_feature.shape[0] > 1:
                 base_image_feature = image_feature[0]
                 image_feature = image_feature[1:]
-                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+                height = width = (
+                    self.config.vision_config.image_size
+                    // self.config.vision_config.patch_size
+                )
 
                 num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                     image_sizes[image_idx],
@ -174,7 +196,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
) )
if ( if (
np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0 np.prod(image_feature.shape)
% (num_patch_height * num_patch_width * height * width)
!= 0
and vision_feature_select_strategy == "default" and vision_feature_select_strategy == "default"
): ):
logger.warning_once( logger.warning_once(
@@ -183,7 +207,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                         " visual encoder that does not have CLS."
                     )
-                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                image_feature = image_feature.view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )
                 image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                 image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                 image_feature = unpad_image(image_feature, image_sizes[image_idx])
@@ -202,11 +228,15 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             else:
                 image_feature = image_feature[0]
                 if image_newline is not None:
-                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+                    image_feature = torch.cat(
+                        (image_feature, image_newline[None].to(image_feature)), dim=0
+                    )
             new_image_features.append(image_feature)
             feature_lens.append(image_feature.size(0))
         image_features = torch.cat(new_image_features, dim=0)
-        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
+        feature_lens = torch.tensor(
+            feature_lens, dtype=torch.long, device=image_features.device
+        )
         return image_features, feature_lens
 
     # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
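(Aside: a self-contained sketch of the reshape/permute/flatten chain that pack_image_features applies to the non-base patches, using toy sizes: a 2x2 patch grid, 24x24 tokens per patch, hidden size 8.)

import torch

num_patch_height, num_patch_width, height, width, hidden = 2, 2, 24, 24, 8
image_feature = torch.randn(num_patch_height * num_patch_width, height * width, hidden)

image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
print(image_feature.shape)  # torch.Size([8, 48, 48]) = (hidden, grid_h * height, grid_w * width)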
@@ -247,11 +277,16 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         ]
         if pixel_values.dim() == 5:
             # stacked if input is (batch_size, num_patches, num_channels, height, width)
-            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
+            _pixel_values_list = [
+                pix_val[:num_patch]
+                for pix_val, num_patch in zip(pixel_values, image_num_patches)
+            ]
             pixel_values = torch.cat(_pixel_values_list, dim=0)
         elif pixel_values.dim() != 4:
             # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
-            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
+            raise ValueError(
+                f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
+            )
 
         image_features = self.vision_tower(pixel_values, output_hidden_states=True)
         # If we have one vision feature layer, return the corresponding hidden states,
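(Aside: a hedged sketch of the 5-D pixel_values branch shown above. Each image in the batch is padded to the same number of patches, so the list comprehension trims every image to its real patch count before concatenation; the shapes below are invented.)

import torch

pixel_values = torch.randn(2, 5, 3, 336, 336)  # (batch, padded patches, C, H, W)
image_num_patches = [5, 3]                     # real patch counts per image

_pixel_values_list = [
    pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
pixel_values = torch.cat(_pixel_values_list, dim=0)
print(pixel_values.shape)  # torch.Size([8, 3, 336, 336])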
@@ -259,7 +294,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         if isinstance(vision_feature_layer, int):
             selected_image_feature = image_features.hidden_states[vision_feature_layer]
         else:
-            hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+            hs_pool = [
+                image_features.hidden_states[layer_idx]
+                for layer_idx in vision_feature_layer
+            ]
             selected_image_feature = torch.cat(hs_pool, dim=-1)
 
         if vision_feature_select_strategy == "default":
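(Aside: a small sketch of the multi-layer branch above, where vision_feature_layer is a list and the selected hidden states are concatenated on the last dimension; the layer indices and shapes are hypothetical.)

import torch

hidden_states = [torch.randn(1, 576, 1024) for _ in range(25)]  # stand-in per-layer outputs
vision_feature_layer = [-2, -5]

hs_pool = [hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
print(selected_image_feature.shape)  # torch.Size([1, 576, 2048])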
@@ -304,8 +342,14 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             position_ids = kwargs.get("position_ids", None)
             labels = kwargs.get("labels", None)
-            if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
-                vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
+            if (
+                past_key_values is None
+                and pixel_values is not None
+                and input_ids.shape[1] != 1
+            ):
+                vision_feature_select_strategy = kwargs.get(
+                    "vision_feature_select_strategy", None
+                )
                 vision_feature_layer = kwargs.get("vision_feature_layer", None)
                 vision_feature_select_strategy = (
                     vision_feature_select_strategy
@@ -313,7 +357,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                     else self.config.vision_feature_select_strategy
                 )
                 vision_feature_layer = (
-                    vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+                    vision_feature_layer
+                    if vision_feature_layer is not None
+                    else self.config.vision_feature_layer
                 )
 
                 # 1. Extract the input embeddings
@@ -334,8 +380,12 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                     image_newline=self.image_newline,
                 )
-                special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+                special_image_mask = (
+                    input_ids == self.config.image_token_index
+                ).unsqueeze(-1)
+                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
+                    inputs_embeds.device
+                )
                 if inputs_embeds[special_image_mask].numel() != image_features.numel():
                     n_image_tokens = (input_ids == self.config.image_token_index).sum()
                     n_image_features = image_features.shape[0]
@@ -343,8 +393,12 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                         f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                     )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+                image_features = image_features.to(
+                    inputs_embeds.device, inputs_embeds.dtype
+                )
+                inputs_embeds = inputs_embeds.masked_scatter(
+                    special_image_mask, image_features
+                )
 
             # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
             # generation with cache
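(Aside: a toy illustration of the masked_scatter step above, which overwrites the embeddings at image-token positions with the packed vision features; the token id and sizes are made up.)

import torch

image_token_index = 32000  # hypothetical image token id
input_ids = torch.tensor([[1, 32000, 32000, 5]])
inputs_embeds = torch.zeros(1, 4, 3)
image_features = torch.arange(6, dtype=torch.float).view(2, 3)  # one row per image token

special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1])  # tensor([0., 1., 2.])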
@@ -356,7 +410,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 # that are set to 0
                 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
                 # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+                batch_index, non_attended_tokens = torch.where(
+                    first_layer_past_key_value.float().sum(-2) == 0
+                )
                 # Get the target length
                 past_length = first_layer_past_key_value.shape[-1]
                 extended_attention_mask = torch.ones(
@@ -383,7 +439,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 position_ids.masked_fill_(attention_mask == 0, 1)
                 if past_key_values:
                     if token_idx is not None:
-                        position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                        position_ids = (
+                            torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                        )
                     else:
                         position_ids = position_ids[:, -input_ids.shape[1] :]
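(Aside: a quick illustration of the token_idx branch above, which derives the next position from the number of attended tokens per row; the attention mask values are invented.)

import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
print(position_ids)  # tensor([[2], [3]])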