Fix llava-next and mllama crash issue

Signed-off-by: yuanwu <yuan.wu@intel.com>

commit bb55318f81 (parent e497bc09f6)
Author: yuanwu, 2025-03-20 05:39:26 +00:00
2 changed files with 216 additions and 230 deletions

View File

@@ -20,13 +20,13 @@ import torch
 import torch.utils.checkpoint
 import numpy as np
+from loguru import logger
 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
 )
 from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
 from transformers.image_processing_utils import select_best_resolution


 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
     Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@@ -49,7 +49,6 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     height, width = select_best_resolution(image_size, grid_pinpoints)
     return height // patch_size, width // patch_size


 # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
 def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
     """
@@ -73,9 +72,7 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
     # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
     if not isinstance(image_size, (list, tuple)):
         if not isinstance(image_size, (torch.Tensor, np.ndarray)):
-            raise TypeError(
-                f"image_size invalid type {type(image_size)} with value {image_size}"
-            )
+            raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
         image_size = image_size.tolist()

     best_resolution = select_best_resolution(image_size, grid_pinpoints)
@@ -89,26 +86,8 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
             num_patches += 1
     return num_patches


 class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
-    def _merge_input_ids_with_image_features(
-        self,
-        inputs_embeds: torch.Tensor,
-        image_features: torch.Tensor,
-        input_ids: torch.Tensor,
-    ):
-        """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
-        # Let's pray we have enabled enough slots !
-        try:
-            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-        except Exception as e:
-            raise RuntimeError(
-                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
-            )
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -126,24 +105,16 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         token_idx: Optional[torch.Tensor] = None,
-        use_flash_attention: Optional[bool] = True,
-        flash_attention_recompute: Optional[bool] = True,
+        use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
     ):

         if token_idx is not None:
-            output_attentions = (
-                output_attentions
-                if output_attentions is not None
-                else self.config.output_attentions
-            )
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
             output_hidden_states = (
-                output_hidden_states
-                if output_hidden_states is not None
-                else self.config.output_hidden_states
+                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
             )
-            return_dict = (
-                return_dict if return_dict is not None else self.config.use_return_dict
-            )
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

             if inputs_embeds is None:
                 inputs_embeds = self.get_input_embeddings()(input_ids)
@@ -169,6 +140,75 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

         return outputs

+    #Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
+    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+        """
+        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+        Args:
+            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+                List of image feature tensor, each contains all the visual feature of all patches.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each images (H, W).
+            vision_feature_select_strategy (`str`)
+                The feature selection strategy used to select the vision feature from the vision backbone.
+            image_newline (`torch.Tensor` of shape `(embed_dim)`)
+                New line embedding vector.
+        Returns:
+            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+            feature_lens (`List[int]`)
+                token length of each image in image_features
+        """
+        new_image_features = []
+        feature_lens = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+                if (
+                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
+                    and vision_feature_select_strategy == "default"
+                ):
+                    logger.warning_once(
+                        "Image feature shape does not line up with the provided patch size. "
+                        "You may be using the `default` vision_feature_select_strategy with a"
+                        " visual encoder that does not have CLS."
+                    )
+                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            image_newline[:, None, None]
+                            .expand(*image_feature.shape[:-1], 1)
+                            .to(image_feature.device, image_feature.dtype),
+                        ),
+                        dim=-1,
+                    )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
+                image_feature = image_feature[0]
+                if image_newline is not None:
+                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+            new_image_features.append(image_feature)
+            feature_lens.append(image_feature.size(0))
+        image_features = torch.cat(new_image_features, dim=0)
+        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
+        return image_features, feature_lens
+
     # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
     def get_image_features(
         self,
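pack_image_features is mostly shape bookkeeping: each image's patch features are viewed as a (grid_height, grid_width, height, width, embed_dim) block, unpadded, and flattened with one image_newline embedding appended per pixel row. A standalone sketch of that geometry with dummy tensors (the unpad_image step is omitted and all sizes are made up for illustration):

import torch

embed_dim = 8
num_patch_height, num_patch_width, height, width = 2, 2, 3, 3

# Dummy patch features for one image, excluding the base patch: (num_patches, height * width, embed_dim).
patch_features = torch.randn(num_patch_height * num_patch_width, height * width, embed_dim)
image_newline = torch.randn(embed_dim)

# Same view/permute/flatten sequence as pack_image_features, minus unpad_image.
x = patch_features.view(num_patch_height, num_patch_width, height, width, -1)
x = x.permute(4, 0, 2, 1, 3).contiguous()  # (embed_dim, nph, h, npw, w)
x = x.flatten(1, 2).flatten(2, 3)          # (embed_dim, nph * h, npw * w)
x = torch.cat((x, image_newline[:, None, None].expand(*x.shape[:-1], 1)), dim=-1)
x = x.flatten(1, 2).transpose(0, 1)        # one newline vector appended per pixel row
print(x.shape)  # torch.Size([42, 8]) = 6 rows * (6 + 1) columns (base patch not included in this sketch)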
@@ -207,16 +247,11 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         ]
         if pixel_values.dim() == 5:
             # stacked if input is (batch_size, num_patches, num_channels, height, width)
-            _pixel_values_list = [
-                pix_val[:num_patch]
-                for pix_val, num_patch in zip(pixel_values, image_num_patches)
-            ]
+            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
             pixel_values = torch.cat(_pixel_values_list, dim=0)
         elif pixel_values.dim() != 4:
             # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
-            raise ValueError(
-                f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
-            )
+            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

         image_features = self.vision_tower(pixel_values, output_hidden_states=True)
         # If we have one vision feature layer, return the corresponding hidden states,
@@ -224,10 +259,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         if isinstance(vision_feature_layer, int):
             selected_image_feature = image_features.hidden_states[vision_feature_layer]
         else:
-            hs_pool = [
-                image_features.hidden_states[layer_idx]
-                for layer_idx in vision_feature_layer
-            ]
+            hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
             selected_image_feature = torch.cat(hs_pool, dim=-1)

         if vision_feature_select_strategy == "default":
@@ -240,186 +272,138 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

         return image_features

     def prepare_inputs_for_generation(
         self,
         input_ids,
         past_key_values=None,
         inputs_embeds=None,
         pixel_values=None,
         image_sizes=None,
         attention_mask=None,
         **kwargs,
     ):
         """
         Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
         The only differences are:
         - add new args token_idx
         - add the process of merging images into inputs_embeds
         """
         token_idx = kwargs.get("token_idx", None)
         if token_idx is None:
             return super().prepare_inputs_for_generation(
                 input_ids=input_ids,
                 past_key_values=past_key_values,
                 inputs_embeds=inputs_embeds,
                 pixel_values=pixel_values,
                 image_sizes=image_sizes,
                 attention_mask=attention_mask,
                 **kwargs,
             )
         else:
-            use_flash_attention = kwargs.get("use_flash_attention", True)
-            flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
-            position_ids = kwargs.get("position_ids", None)
-            labels = kwargs.get("labels", None)
-            if (
-                past_key_values is None
-                and pixel_values is not None
-                and input_ids.shape[1] != 1
-            ):
-                vision_feature_select_strategy = kwargs.get(
-                    "vision_feature_select_strategy", None
-                )
-                vision_feature_layer = kwargs.get("vision_feature_layer", None)
-                vision_feature_select_strategy = (
-                    vision_feature_select_strategy
-                    if vision_feature_select_strategy is not None
-                    else self.config.vision_feature_select_strategy
-                )
-                vision_feature_layer = (
-                    vision_feature_layer
-                    if vision_feature_layer is not None
-                    else self.config.vision_feature_layer
-                )
-                # 1. Extract the input embeddings
-                inputs_embeds = self.get_input_embeddings()(input_ids)
-                # 2. Merge text and images
-                image_features = self.get_image_features(
-                    pixel_values,
-                    image_sizes,
-                    vision_feature_layer=vision_feature_layer,
-                    vision_feature_select_strategy=vision_feature_select_strategy,
-                )
-                # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-                height = width = (
-                    self.config.vision_config.image_size
-                    // self.config.vision_config.patch_size
-                )
-                new_image_features = []
-                for image_idx, image_feature in enumerate(image_features):
-                    if image_feature.shape[0] > 1:
-                        base_image_feature = image_feature[0]
-                        image_feature = image_feature[1:]
-                        if height * width != base_image_feature.shape[0]:
-                            raise ValueError(
-                                "The number of patches is not consistent with the image size."
-                            )
-                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-                            image_sizes[image_idx].tolist(),
-                            self.config.image_grid_pinpoints,
-                            self.config.vision_config.image_size,
-                        )
-                        image_feature = image_feature.view(
-                            num_patch_height, num_patch_width, height, width, -1
-                        )
-                        image_feature = image_feature.permute(
-                            4, 0, 2, 1, 3
-                        ).contiguous()
-                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                        image_feature = unpad_image(
-                            image_feature, image_sizes[image_idx]
-                        )
-                        image_feature = torch.cat(
-                            (
-                                image_feature,
-                                self.image_newline[:, None, None].expand(
-                                    *image_feature.shape[:-1], 1
-                                ),
-                            ),
-                            dim=-1,
-                        )
-                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        image_feature = torch.cat(
-                            (base_image_feature, image_feature), dim=0
-                        )
-                    else:
-                        image_feature = image_feature[0]
-                        image_feature = torch.cat(
-                            (image_feature, self.image_newline[None]), dim=0
-                        )
-                    new_image_features.append(image_feature)
-                image_features = torch.cat(new_image_features, dim=0)
-                inputs_embeds = self._merge_input_ids_with_image_features(
-                    inputs_embeds, image_features, input_ids
-                )
-            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
-            # generation with cache
-            elif past_key_values is not None:
-                seq_len = input_ids.shape[1]
-                pad_len = seq_len - token_idx.item()
-                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
-                # Retrieve the first layer to inspect the logits and mask out the hidden states
-                # that are set to 0
-                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                batch_index, non_attended_tokens = torch.where(
-                    first_layer_past_key_value.float().sum(-2) == 0
-                )
-                # Get the target length
-                past_length = first_layer_past_key_value.shape[-1]
-                extended_attention_mask = torch.ones(
-                    (attention_mask.shape[0], past_length),
-                    dtype=attention_mask.dtype,
-                    device=attention_mask.device,
-                )
-                # Filter out only the tokens that can be un-attended, this can happen
-                # if one uses Llava + Fused modules where the cache on the
-                # first iteration is already big enough, or if one passes custom cache
-                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                new_batch_index = batch_index[valid_indices]
-                new_non_attended_tokens = non_attended_tokens[valid_indices]
-                # Zero-out the places where we don't need to attend
-                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-                attention_mask = extended_attention_mask
-                attention_mask[:, -pad_len:] = 0
-            if attention_mask is not None and position_ids is None:
-                # create position_ids on the fly for batch generation
-                position_ids = attention_mask.long().cumsum(-1) - 1
-                position_ids.masked_fill_(attention_mask == 0, 1)
-                if past_key_values:
-                    if token_idx is not None:
-                        position_ids = (
-                            torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                        )
-                    else:
-                        position_ids = position_ids[:, -input_ids.shape[1] :]
-            # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-            if inputs_embeds is not None and past_key_values is None:
-                model_inputs = {"inputs_embeds": inputs_embeds}
-            else:
-                model_inputs = {"input_ids": input_ids}
-            model_inputs.update(
-                {
-                    "position_ids": position_ids,
-                    "past_key_values": past_key_values,
-                    "use_cache": kwargs.get("use_cache"),
-                    "attention_mask": attention_mask,
-                    "token_idx": token_idx,
-                    "labels": labels,
-                    "use_flash_attention": use_flash_attention,
-                    "flash_attention_recompute": flash_attention_recompute,
-                }
-            )
-            return model_inputs
+            use_flash_attention = kwargs.get("use_flash_attention", False)
+            flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
+            position_ids = kwargs.get("position_ids", None)
+            labels = kwargs.get("labels", None)
+            if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
+                vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
+                vision_feature_layer = kwargs.get("vision_feature_layer", None)
+                vision_feature_select_strategy = (
+                    vision_feature_select_strategy
+                    if vision_feature_select_strategy is not None
+                    else self.config.vision_feature_select_strategy
+                )
+                vision_feature_layer = (
+                    vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+                )
+                # 1. Extract the input embeddings
+                inputs_embeds = self.get_input_embeddings()(input_ids)
+                # 2. Merge text and images
+                image_features = self.get_image_features(
+                    pixel_values,
+                    image_sizes,
+                    vision_feature_layer=vision_feature_layer,
+                    vision_feature_select_strategy=vision_feature_select_strategy,
+                )
+                # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+                image_features, feature_lens = self.pack_image_features(
+                    image_features,
+                    image_sizes,
+                    vision_feature_select_strategy=vision_feature_select_strategy,
+                    image_newline=self.image_newline,
+                )
+                special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+                if inputs_embeds[special_image_mask].numel() != image_features.numel():
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum()
+                    n_image_features = image_features.shape[0]
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
+                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+            # generation with cache
+            elif past_key_values is not None:
+                seq_len = input_ids.shape[1]
+                pad_len = seq_len - token_idx.item()
+                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+                # Get the target length
+                past_length = first_layer_past_key_value.shape[-1]
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+                attention_mask = extended_attention_mask
+                attention_mask[:, -pad_len:] = 0
+            if attention_mask is not None and position_ids is None:
+                # create position_ids on the fly for batch generation
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                if past_key_values:
+                    if token_idx is not None:
+                        position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                    else:
+                        position_ids = position_ids[:, -input_ids.shape[1] :]
+            # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+            if inputs_embeds is not None and past_key_values is None:
+                model_inputs = {"inputs_embeds": inputs_embeds}
+            else:
+                model_inputs = {"input_ids": input_ids}
+            model_inputs.update(
+                {
+                    "position_ids": position_ids,
+                    "past_key_values": past_key_values,
+                    "use_cache": kwargs.get("use_cache"),
+                    "attention_mask": attention_mask,
+                    "token_idx": token_idx,
+                    "labels": labels,
+                    "use_flash_attention": use_flash_attention,
+                    "flash_attention_recompute": flash_attention_recompute,
+                }
+            )
+            return model_inputs
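The crash fix in prepare_inputs_for_generation replaces the removed _merge_input_ids_with_image_features helper (an in-place inputs_embeds[mask] = ... fill wrapped in try/except) with an explicit token/feature count check followed by masked_scatter. A standalone sketch of that merge with dummy tensors (the placeholder token id and all shapes are made up for illustration):

import torch

image_token_index = 32000  # assumed placeholder id, for the sketch only
embed_dim = 8

input_ids = torch.tensor([[1, image_token_index, image_token_index, 5, 6]])
inputs_embeds = torch.randn(1, 5, embed_dim)
image_features = torch.randn(2, embed_dim)  # one packed feature row per image placeholder token

special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)

# Validate up front instead of catching an indexing failure after the fact.
if inputs_embeds[special_image_mask].numel() != image_features.numel():
    raise ValueError("Image features and image tokens do not match")
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds.shape)  # torch.Size([1, 5, 8]); positions 1 and 2 now hold the image features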

View File

@@ -428,6 +428,9 @@ class VlmCausalLMBatch(CausalLMBatch):
             else:
                 images.append(curr_image)

+        if is_warmup is True:
+            images += [images[0]] * (len(texts) - len(images))
+
         missing_inputs = 0
         dummy_images = None
         if is_warmup is False:
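The added branch pads the image list up to the number of prompts during warmup by reusing the first image, so texts and images stay aligned. A standalone sketch of that padding (the list contents are stand-ins for real prompts and PIL images):

texts = ["prompt 0", "prompt 1", "prompt 2"]
images = ["image 0"]  # stand-in for decoded PIL images
is_warmup = True

if is_warmup is True:
    images += [images[0]] * (len(texts) - len(images))

assert len(images) == len(texts)  # images is now ["image 0", "image 0", "image 0"]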
@@ -1464,7 +1467,6 @@ class VlmCausalLM(Model):
         batch = self.batch_from_pb(request.batch, is_warmup=True)
         max_input_tokens = request.max_input_tokens
         max_prefill_batch_size = batch.input_ids.shape[0]
-
         try:
             # max prefill batch size warmup
             _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
@@ -1548,7 +1550,7 @@ class VlmCausalLM(Model):
                 request,
                 PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                 max_prefill_batch_size,
-                is_warmup=False,
+                is_warmup=True,
             )
             _, prefill_batch, _ = self.generate_token(
                 [batch], is_warmup=True
@@ -1568,7 +1570,7 @@ class VlmCausalLM(Model):
                 request,
                 PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                 2,
-                is_warmup=False,
+                is_warmup=True,
             )
             _, prefill_batch, _ = self.generate_token(
                 [batch], is_warmup=True