Gaudi: Fix llava-next and mllama crash issue (#3127)

Signed-off-by: yuanwu <yuan.wu@intel.com>
Yuan Wu 2025-03-25 22:08:15 +08:00 committed by GitHub
parent 54d15462dc
commit f5f14dc660
2 changed files with 116 additions and 72 deletions


@@ -20,6 +20,7 @@ import torch
import torch.utils.checkpoint
import numpy as np
from loguru import logger
from transformers.models.llava_next.modeling_llava_next import (
    unpad_image,
)
@@ -92,23 +93,6 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
def _merge_input_ids_with_image_features(
self,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
input_ids: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots !
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds
    def forward(
        self,
        input_ids: torch.LongTensor = None,
@@ -169,6 +153,92 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
        return outputs
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
def pack_image_features(
self,
image_features,
image_sizes,
vision_feature_select_strategy,
image_newline=None,
):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
Args:
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
List of image feature tensor, each contains all the visual feature of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
vision_feature_select_strategy (`str`)
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
token length of each image in image_features
"""
new_image_features = []
feature_lens = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
if (
np.prod(image_feature.shape)
% (num_patch_height * num_patch_width * height * width)
!= 0
and vision_feature_select_strategy == "default"
):
logger.warning_once(
"Image feature shape does not line up with the provided patch size. "
"You may be using the `default` vision_feature_select_strategy with a"
" visual encoder that does not have CLS."
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
if image_newline is not None:
image_feature = torch.cat(
(
image_feature,
image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device, image_feature.dtype),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if image_newline is not None:
image_feature = torch.cat(
(image_feature, image_newline[None].to(image_feature)), dim=0
)
new_image_features.append(image_feature)
feature_lens.append(image_feature.size(0))
image_features = torch.cat(new_image_features, dim=0)
feature_lens = torch.tensor(
feature_lens, dtype=torch.long, device=image_features.device
)
return image_features, feature_lens
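For orientation only, the standalone sketch below mimics the single-patch branch of pack_image_features with invented tensor sizes: each per-image feature tensor gets the newline embedding appended, everything is concatenated into one flat tensor, and feature_lens records how many rows each image contributed. It is not part of the patch, just a minimal illustration of the shapes involved.

import torch

# Illustrative sizes only; real models use much larger embedding dims.
embed_dim = 8
image_newline = torch.zeros(embed_dim)

# Two "images", each arriving as a single patch of 4 and 6 visual tokens,
# mimicking the `image_feature.shape[0] == 1` branch above.
image_features = [torch.randn(1, 4, embed_dim), torch.randn(1, 6, embed_dim)]

new_image_features, feature_lens = [], []
for feat in image_features:
    feat = feat[0]                                        # (num_tokens, embed_dim)
    feat = torch.cat((feat, image_newline[None]), dim=0)  # append newline embedding
    new_image_features.append(feat)
    feature_lens.append(feat.size(0))

packed = torch.cat(new_image_features, dim=0)             # (all_feat_len, embed_dim)
print(packed.shape, feature_lens)                         # torch.Size([12, 8]) [5, 7]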
    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
    def get_image_features(
        self,
@@ -303,61 +373,33 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                )
                # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
                height = width = (
                    self.config.vision_config.image_size
                    // self.config.vision_config.patch_size
                )
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        if height * width != base_image_feature.shape[0]:
                            raise ValueError(
                                "The number of patches is not consistent with the image size."
                            )
                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                            image_sizes[image_idx].tolist(),
                            self.config.image_grid_pinpoints,
                            self.config.vision_config.image_size,
                        )
                        image_feature = image_feature.view(
                            num_patch_height, num_patch_width, height, width, -1
                        )
                        image_feature = image_feature.permute(
                            4, 0, 2, 1, 3
                        ).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(
                            image_feature, image_sizes[image_idx]
                        )
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.image_newline[:, None, None].expand(
                                    *image_feature.shape[:-1], 1
                                ),
                            ),
                            dim=-1,
                        )
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        image_feature = torch.cat(
                            (base_image_feature, image_feature), dim=0
                        )
                    else:
                        image_feature = image_feature[0]
                        image_feature = torch.cat(
                            (image_feature, self.image_newline[None]), dim=0
                        )
                    new_image_features.append(image_feature)
                image_features = torch.cat(new_image_features, dim=0)
                inputs_embeds = self._merge_input_ids_with_image_features(
                    inputs_embeds, image_features, input_ids
                )
                image_features, feature_lens = self.pack_image_features(
                    image_features,
                    image_sizes,
                    vision_feature_select_strategy=vision_feature_select_strategy,
                    image_newline=self.image_newline,
                )
                special_image_mask = (
                    input_ids == self.config.image_token_index
                ).unsqueeze(-1)
                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
                    inputs_embeds.device
                )
                if inputs_embeds[special_image_mask].numel() != image_features.numel():
                    n_image_tokens = (input_ids == self.config.image_token_index).sum()
                    n_image_features = image_features.shape[0]
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                image_features = image_features.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(
                    special_image_mask, image_features
                )
            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
            # generation with cache
            elif past_key_values is not None:
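To see the new merge path in isolation, here is a minimal sketch with invented shapes and an invented image placeholder id: the placeholder positions in input_ids are expanded into a boolean mask over the embedding tensor, the element counts are checked, and the packed image features are written in with masked_scatter, which replaces the removed in-place indexed assignment from _merge_input_ids_with_image_features.

import torch

# Invented placeholder id and tiny shapes, for illustration only.
image_token_index = 32000
input_ids = torch.tensor([[1, 32000, 32000, 2]])      # (batch=1, seq_len=4)
inputs_embeds = torch.zeros(1, 4, 3)                  # text embeddings, embed_dim=3
image_features = torch.ones(2, 3)                     # one row per image placeholder

# Same construction as the new code: mask the placeholder positions,
# broadcast the mask over the embedding dimension, then scatter the image rows in.
special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)

assert inputs_embeds[special_image_mask].numel() == image_features.numel()
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0])  # rows 1 and 2 are now filled with the image features

The up-front element-count check is what turns a silent shape mismatch into the explicit "Image features and image tokens do not match" error raised in the patch.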


@@ -428,6 +428,9 @@ class VlmCausalLMBatch(CausalLMBatch):
                else:
                    images.append(curr_image)

        if is_warmup is True:
            images += [images[0]] * (len(texts) - len(images))

        missing_inputs = 0
        dummy_images = None
        if is_warmup is False:
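The added warmup guard can be pictured with plain Python lists (values invented): when the warmup batch carries more prompts than decoded images, the first image is repeated so every prompt has one.

# Hypothetical warmup batch: three prompts but only one decoded image.
texts = ["<image> caption this"] * 3
images = ["img0"]
is_warmup = True

# Mirrors the added guard in VlmCausalLMBatch: repeat the first image so the
# number of images matches the number of prompts during warmup.
if is_warmup is True:
    images += [images[0]] * (len(texts) - len(images))

print(images)  # ['img0', 'img0', 'img0']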
@@ -1464,7 +1467,6 @@ class VlmCausalLM(Model):
        batch = self.batch_from_pb(request.batch, is_warmup=True)
        max_input_tokens = request.max_input_tokens
        max_prefill_batch_size = batch.input_ids.shape[0]
        try:
            # max prefill batch size warmup
            _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
@@ -1548,7 +1550,7 @@ class VlmCausalLM(Model):
                    request,
                    PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                    max_prefill_batch_size,
                    is_warmup=False,
                    is_warmup=True,
                )
                _, prefill_batch, _ = self.generate_token(
                    [batch], is_warmup=True
@@ -1568,7 +1570,7 @@ class VlmCausalLM(Model):
                    request,
                    PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                    2,
                    is_warmup=False,
                    is_warmup=True,
                )
                _, prefill_batch, _ = self.generate_token(
                    [batch], is_warmup=True