Fix llava-next and mllama crash issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
yuanwu 2025-03-20 05:39:26 +00:00
parent e497bc09f6
commit bb55318f81
2 changed files with 216 additions and 230 deletions

View File

@ -20,13 +20,13 @@ import torch
import torch.utils.checkpoint import torch.utils.checkpoint
import numpy as np import numpy as np
from loguru import logger
from transformers.models.llava_next.modeling_llava_next import ( from transformers.models.llava_next.modeling_llava_next import (
unpad_image, unpad_image,
) )
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
from transformers.image_processing_utils import select_best_resolution from transformers.image_processing_utils import select_best_resolution
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
""" """
Calculate the shape of the image patch grid after the preprocessing for images of any resolution. Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@ -49,7 +49,6 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
height, width = select_best_resolution(image_size, grid_pinpoints) height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size return height // patch_size, width // patch_size
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79 # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
""" """
@ -73,9 +72,7 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
# ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
if not isinstance(image_size, (list, tuple)): if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)): if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError( raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
f"image_size invalid type {type(image_size)} with value {image_size}"
)
image_size = image_size.tolist() image_size = image_size.tolist()
best_resolution = select_best_resolution(image_size, grid_pinpoints) best_resolution = select_best_resolution(image_size, grid_pinpoints)
@ -89,26 +86,8 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
num_patches += 1 num_patches += 1
return num_patches return num_patches
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
def _merge_input_ids_with_image_features(
self,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
input_ids: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots !
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
@ -126,24 +105,16 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None, token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = True, use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = True, flash_attention_recompute: Optional[bool] = False,
): ):
if token_idx is not None: if token_idx is not None:
output_attentions = ( output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = ( output_hidden_states = (
output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@ -169,6 +140,75 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
return outputs return outputs
#Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
Args:
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
List of image feature tensor, each contains all the visual feature of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
vision_feature_select_strategy (`str`)
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
token length of each image in image_features
"""
new_image_features = []
feature_lens = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
if (
np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
and vision_feature_select_strategy == "default"
):
logger.warning_once(
"Image feature shape does not line up with the provided patch size. "
"You may be using the `default` vision_feature_select_strategy with a"
" visual encoder that does not have CLS."
)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
if image_newline is not None:
image_feature = torch.cat(
(
image_feature,
image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device, image_feature.dtype),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if image_newline is not None:
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
new_image_features.append(image_feature)
feature_lens.append(image_feature.size(0))
image_features = torch.cat(new_image_features, dim=0)
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
return image_features, feature_lens
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479 # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
def get_image_features( def get_image_features(
self, self,
@ -207,16 +247,11 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
] ]
if pixel_values.dim() == 5: if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width) # stacked if input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [ _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
pix_val[:num_patch]
for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
pixel_values = torch.cat(_pixel_values_list, dim=0) pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4: elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width) # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError( raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
)
image_features = self.vision_tower(pixel_values, output_hidden_states=True) image_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states, # If we have one vision feature layer, return the corresponding hidden states,
@ -224,10 +259,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
if isinstance(vision_feature_layer, int): if isinstance(vision_feature_layer, int):
selected_image_feature = image_features.hidden_states[vision_feature_layer] selected_image_feature = image_features.hidden_states[vision_feature_layer]
else: else:
hs_pool = [ hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
image_features.hidden_states[layer_idx]
for layer_idx in vision_feature_layer
]
selected_image_feature = torch.cat(hs_pool, dim=-1) selected_image_feature = torch.cat(hs_pool, dim=-1)
if vision_feature_select_strategy == "default": if vision_feature_select_strategy == "default":
@ -267,19 +299,13 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
**kwargs, **kwargs,
) )
else: else:
use_flash_attention = kwargs.get("use_flash_attention", True) use_flash_attention = kwargs.get("use_flash_attention", False)
flash_attention_recompute = kwargs.get("flash_attention_recompute", True) flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
position_ids = kwargs.get("position_ids", None) position_ids = kwargs.get("position_ids", None)
labels = kwargs.get("labels", None) labels = kwargs.get("labels", None)
if ( if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
past_key_values is None vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
and pixel_values is not None
and input_ids.shape[1] != 1
):
vision_feature_select_strategy = kwargs.get(
"vision_feature_select_strategy", None
)
vision_feature_layer = kwargs.get("vision_feature_layer", None) vision_feature_layer = kwargs.get("vision_feature_layer", None)
vision_feature_select_strategy = ( vision_feature_select_strategy = (
vision_feature_select_strategy vision_feature_select_strategy
@ -287,9 +313,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
else self.config.vision_feature_select_strategy else self.config.vision_feature_select_strategy
) )
vision_feature_layer = ( vision_feature_layer = (
vision_feature_layer vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
if vision_feature_layer is not None
else self.config.vision_feature_layer
) )
# 1. Extract the input embeddings # 1. Extract the input embeddings
@ -303,61 +327,25 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
) )
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad" # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = ( image_features, feature_lens = self.pack_image_features(
self.config.vision_config.image_size image_features,
// self.config.vision_config.patch_size image_sizes,
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
) )
new_image_features = [] special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
for image_idx, image_feature in enumerate(image_features): special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if image_feature.shape[0] > 1: if inputs_embeds[special_image_mask].numel() != image_features.numel():
base_image_feature = image_feature[0] n_image_tokens = (input_ids == self.config.image_token_index).sum()
image_feature = image_feature[1:] n_image_features = image_features.shape[0]
if height * width != base_image_feature.shape[0]:
raise ValueError( raise ValueError(
"The number of patches is not consistent with the image size." f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
) )
num_patch_height, num_patch_width = get_anyres_image_grid_shape( image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
image_sizes[image_idx].tolist(), inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(
image_feature, image_sizes[image_idx]
)
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.cat(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
inputs_embeds, image_features, input_ids
)
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache # generation with cache
elif past_key_values is not None: elif past_key_values is not None:
@ -368,9 +356,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
# that are set to 0 # that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where( batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
first_layer_past_key_value.float().sum(-2) == 0
)
# Get the target length # Get the target length
past_length = first_layer_past_key_value.shape[-1] past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones( extended_attention_mask = torch.ones(
@ -397,9 +383,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
position_ids.masked_fill_(attention_mask == 0, 1) position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values: if past_key_values:
if token_idx is not None: if token_idx is not None:
position_ids = ( position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
)
else: else:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]

View File

@ -428,6 +428,9 @@ class VlmCausalLMBatch(CausalLMBatch):
else: else:
images.append(curr_image) images.append(curr_image)
if is_warmup is True:
images += [images[0]] * (len(texts) - len(images))
missing_inputs = 0 missing_inputs = 0
dummy_images = None dummy_images = None
if is_warmup is False: if is_warmup is False:
@ -1464,7 +1467,6 @@ class VlmCausalLM(Model):
batch = self.batch_from_pb(request.batch, is_warmup=True) batch = self.batch_from_pb(request.batch, is_warmup=True)
max_input_tokens = request.max_input_tokens max_input_tokens = request.max_input_tokens
max_prefill_batch_size = batch.input_ids.shape[0] max_prefill_batch_size = batch.input_ids.shape[0]
try: try:
# max prefill batch size warmup # max prefill batch size warmup
_, prefill_batch, _ = self.generate_token([batch], is_warmup=True) _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
@ -1548,7 +1550,7 @@ class VlmCausalLM(Model):
request, request,
PREFILL_WARMUP_SEQLEN_LIST[0] - 1, PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
max_prefill_batch_size, max_prefill_batch_size,
is_warmup=False, is_warmup=True,
) )
_, prefill_batch, _ = self.generate_token( _, prefill_batch, _ = self.generate_token(
[batch], is_warmup=True [batch], is_warmup=True
@ -1568,7 +1570,7 @@ class VlmCausalLM(Model):
request, request,
PREFILL_WARMUP_SEQLEN_LIST[0] - 1, PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
2, 2,
is_warmup=False, is_warmup=True,
) )
_, prefill_batch, _ = self.generate_token( _, prefill_batch, _ = self.generate_token(
[batch], is_warmup=True [batch], is_warmup=True