Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 13:52:07 +00:00)
Gaudi: Fix llava-next and mllama crash issue (#3127)
Signed-off-by: yuanwu <yuan.wu@intel.com>
parent: 54d15462dc
commit: f5f14dc660
@@ -20,6 +20,7 @@ import torch
import torch.utils.checkpoint
import numpy as np

from loguru import logger
from transformers.models.llava_next.modeling_llava_next import (
    unpad_image,
)
@@ -92,23 +93,6 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):

class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

    def _merge_input_ids_with_image_features(
        self,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
        input_ids: torch.Tensor,
    ):
        """In place merges in vision_embeddings with inputs_embeds."""
        mask = input_ids == self.config.image_token_index
        # Let's pray we have enabled enough slots !
        try:
            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        except Exception as e:
            raise RuntimeError(
                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.LongTensor = None,
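Note: the `_merge_input_ids_with_image_features` helper removed above filled the image-token positions of `inputs_embeds` with a plain boolean-mask assignment. A minimal standalone sketch of that pattern, using a made-up token id and tensor shapes rather than the model's real config:

import torch

image_token_index = 32000                    # assumption: placeholder image-token id
input_ids = torch.tensor([[1, 32000, 32000, 2]])
inputs_embeds = torch.zeros(1, 4, 8)         # (batch, seq_len, embed_dim)
image_features = torch.randn(2, 8)           # one feature row per image token

mask = input_ids == image_token_index        # (batch, seq_len) boolean mask
# In-place write: the number of True positions must match the number of
# feature rows, which is what the removed try/except guarded against.
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])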
@@ -169,6 +153,92 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

        return outputs

    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
    def pack_image_features(
        self,
        image_features,
        image_sizes,
        vision_feature_select_strategy,
        image_newline=None,
    ):
        """
        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

        Args:
            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensor, each contains all the visual feature of all patches.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
            image_newline (`torch.Tensor` of shape `(embed_dim)`)
                New line embedding vector.
        Returns:
            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
            feature_lens (`List[int]`)
                token length of each image in image_features
        """
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = (
                    self.config.vision_config.image_size
                    // self.config.vision_config.patch_size
                )

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )

                if (
                    np.prod(image_feature.shape)
                    % (num_patch_height * num_patch_width * height * width)
                    != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

                image_feature = image_feature.view(
                    num_patch_height, num_patch_width, height, width, -1
                )
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat(
                        (image_feature, image_newline[None].to(image_feature)), dim=0
                    )
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        image_features = torch.cat(new_image_features, dim=0)
        feature_lens = torch.tensor(
            feature_lens, dtype=torch.long, device=image_features.device
        )
        return image_features, feature_lens

    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
    def get_image_features(
        self,
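As a rough illustration of the reshape/unpad sequence inside `pack_image_features`, the sketch below runs the same tensor operations on made-up numbers (a 2x2 patch grid, 24x24 tokens per patch, embed_dim 16, and an arbitrary original image size); none of these values come from a real checkpoint:

import torch
from transformers.models.llava_next.modeling_llava_next import unpad_image

num_patch_height, num_patch_width, height, width, embed_dim = 2, 2, 24, 24, 16
# One image's patch features, excluding the base (downscaled) patch.
image_feature = torch.randn(num_patch_height * num_patch_width, height * width, embed_dim)

image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)       # (embed_dim, H, W) token grid
image_feature = unpad_image(image_feature, (300, 400))          # hypothetical original (height, width)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)     # back to (num_tokens, embed_dim)
print(image_feature.shape)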
@@ -303,61 +373,33 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
            )

            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
            height = width = (
                self.config.vision_config.image_size
                // self.config.vision_config.patch_size
            image_features, feature_lens = self.pack_image_features(
                image_features,
                image_sizes,
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_newline=self.image_newline,
            )

            new_image_features = []
            for image_idx, image_feature in enumerate(image_features):
                if image_feature.shape[0] > 1:
                    base_image_feature = image_feature[0]
                    image_feature = image_feature[1:]

                    if height * width != base_image_feature.shape[0]:
            special_image_mask = (
                input_ids == self.config.image_token_index
            ).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
                inputs_embeds.device
            )
            if inputs_embeds[special_image_mask].numel() != image_features.numel():
                n_image_tokens = (input_ids == self.config.image_token_index).sum()
                n_image_features = image_features.shape[0]
                raise ValueError(
                    "The number of patches is not consistent with the image size."
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )

                    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                        image_sizes[image_idx].tolist(),
                        self.config.image_grid_pinpoints,
                        self.config.vision_config.image_size,
            image_features = image_features.to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            inputs_embeds = inputs_embeds.masked_scatter(
                special_image_mask, image_features
            )

                    image_feature = image_feature.view(
                        num_patch_height, num_patch_width, height, width, -1
                    )
                    image_feature = image_feature.permute(
                        4, 0, 2, 1, 3
                    ).contiguous()
                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                    image_feature = unpad_image(
                        image_feature, image_sizes[image_idx]
                    )
                    image_feature = torch.cat(
                        (
                            image_feature,
                            self.image_newline[:, None, None].expand(
                                *image_feature.shape[:-1], 1
                            ),
                        ),
                        dim=-1,
                    )
                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    image_feature = torch.cat(
                        (base_image_feature, image_feature), dim=0
                    )
                else:
                    image_feature = image_feature[0]
                    image_feature = torch.cat(
                        (image_feature, self.image_newline[None]), dim=0
                    )
                new_image_features.append(image_feature)
            image_features = torch.cat(new_image_features, dim=0)
            inputs_embeds = self._merge_input_ids_with_image_features(
                inputs_embeds, image_features, input_ids
            )
        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
        # generation with cache
        elif past_key_values is not None:
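The rewritten `forward` path merges image features into the text embeddings with `torch.Tensor.masked_scatter` instead of the old in-place mask assignment. A self-contained sketch of just that step, again with hypothetical token id and shapes:

import torch

image_token_index = 32000                    # assumption: placeholder image-token id
input_ids = torch.tensor([[1, 32000, 32000, 2]])
inputs_embeds = torch.zeros(1, 4, 8)         # (batch, seq_len, embed_dim)
image_features = torch.randn(2, 8)           # packed features, one row per image token

special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)
# masked_scatter copies image_features element-wise into the masked positions
# and returns a new tensor rather than writing through an advanced index.
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)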
@@ -428,6 +428,9 @@ class VlmCausalLMBatch(CausalLMBatch):
            else:
                images.append(curr_image)

        if is_warmup is True:
            images += [images[0]] * (len(texts) - len(images))

        missing_inputs = 0
        dummy_images = None
        if is_warmup is False:
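The added warmup branch pads the image list by reusing the first image so every warmup prompt has one; a tiny sketch of that list arithmetic with placeholder values:

texts = ["prompt with image", "text-only prompt", "text-only prompt"]
images = ["first_image"]                      # placeholder objects, not real PIL images
is_warmup = True

if is_warmup is True:
    # Reuse the first image so len(images) == len(texts) during warmup.
    images += [images[0]] * (len(texts) - len(images))

assert len(images) == len(texts)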
@@ -1464,7 +1467,6 @@ class VlmCausalLM(Model):
        batch = self.batch_from_pb(request.batch, is_warmup=True)
        max_input_tokens = request.max_input_tokens
        max_prefill_batch_size = batch.input_ids.shape[0]

        try:
            # max prefill batch size warmup
            _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
@@ -1548,7 +1550,7 @@ class VlmCausalLM(Model):
                request,
                PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                max_prefill_batch_size,
                is_warmup=False,
                is_warmup=True,
            )
            _, prefill_batch, _ = self.generate_token(
                [batch], is_warmup=True
@@ -1568,7 +1570,7 @@ class VlmCausalLM(Model):
                request,
                PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                2,
                is_warmup=False,
                is_warmup=True,
            )
            _, prefill_batch, _ = self.generate_token(
                [batch], is_warmup=True