Fix Llava next crash issue (#285)

Signed-off-by: yuanwu <yuan.wu@intel.com>
Yuan Wu 2025-03-06 17:12:21 +08:00 committed by GitHub
parent 20ea73c6d4
commit cd57fea11b
2 changed files with 198 additions and 84 deletions


@@ -14,14 +14,13 @@
# limitations under the License.
""" PyTorch Llava-NeXT model."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
import torch
import torch.utils.checkpoint
-from torch import nn
+import numpy as np
from loguru import logger
-from transformers.activations import ACT2FN
from transformers.models.llava_next.modeling_llava_next import (
    unpad_image,
)
@@ -50,25 +49,44 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
"""
Calculate the number of patches after the preprocessing for images of any resolution.
Args:
image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
int: the number of patches
"""
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
# ! VERY IMPORTANT: if image_size is a tensor, it must be converted to a tuple, otherwise the calculation will be wrong
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
image_size = image_size.tolist()
best_resolution = select_best_resolution(image_size, grid_pinpoints)
height, width = best_resolution
num_patches = 0
# consider changing to ceil(height / patch_size) * ceil(width / patch_size) + 1
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
num_patches += 1
# add the base patch
num_patches += 1
return num_patches
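
# A worked example of the count above, assuming the usual llava-v1.6 configuration where
# the `patch_size` argument is vision_config.image_size (336) and select_best_resolution
# picks a 672x672 canvas for a roughly square input image (values are illustrative):
height, width, patch_size = 672, 672, 336
num_patches = sum(1 for _ in range(0, height, patch_size) for _ in range(0, width, patch_size))
num_patches += 1  # the base (resized full-image) patch
print(num_patches)  # 5 -> 4 high-resolution tiles + 1 base image
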
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
def _merge_input_ids_with_image_features(
self,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
input_ids: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots !
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds
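
# A toy, self-contained illustration of the in-place boolean-mask fill used by the method
# above (batch, sequence and embedding sizes plus the image token id 32000 are made up):
import torch
input_ids = torch.tensor([[1, 32000, 32000, 5]])
inputs_embeds = torch.zeros(1, 4, 8)
image_features = torch.ones(2, 1, 8)             # one embedding row per image token
mask = input_ids == 32000                        # (batch, seq_len)
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
print(inputs_embeds[0, 1, 0].item())             # 1.0
# The assignment only succeeds when the number of image tokens matches the number of
# feature rows; a mismatch raises an error, which the wrapper above turns into the hint
# about `--max-input-tokens`.
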
def forward(
self,
@@ -121,6 +139,136 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
return output
return outputs
#Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
Args:
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
List of image feature tensors, each containing all the visual features of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_select_strategy (`str`)
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
token length of each image in image_features
"""
new_image_features = []
feature_lens = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
if (
np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
and vision_feature_select_strategy == "default"
):
logger.warning_once(
"Image feature shape does not line up with the provided patch size. "
"You may be using the `default` vision_feature_select_strategy with a"
" visual encoder that does not have CLS."
)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
if image_newline is not None:
image_feature = torch.cat(
(
image_feature,
image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device, image_feature.dtype),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if image_newline is not None:
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
new_image_features.append(image_feature)
feature_lens.append(image_feature.size(0))
image_features = torch.cat(new_image_features, dim=0)
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
return image_features, feature_lens
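
# A toy shape walk-through of the spatial re-assembly above (all sizes are illustrative:
# a CLIP-style vision_config.image_size of 336 and patch_size of 14 give height = width = 24,
# a 2x2 tile grid plus the base tile, and a deliberately small embed_dim):
import torch
embed_dim = 16
image_feature = torch.randn(5, 24 * 24, embed_dim)    # (num_patches, image_length, embed_dim)
base, tiles = image_feature[0], image_feature[1:]
tiles = tiles.view(2, 2, 24, 24, embed_dim)           # (grid_h, grid_w, h, w, d)
tiles = tiles.permute(4, 0, 2, 1, 3).contiguous()     # (d, grid_h, h, grid_w, w)
tiles = tiles.flatten(1, 2).flatten(2, 3)             # (d, grid_h*h, grid_w*w) == (d, 48, 48)
print(tiles.shape)                                    # torch.Size([16, 48, 48])
# unpad_image then crops the padded border, an image_newline column is appended, and
# flatten(1, 2).transpose(0, 1) yields (token_len, d), concatenated after `base`.
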
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
):
"""
Obtains the image last hidden states from the vision tower and applies the multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches
and of shape `(num_patches, image_length, embed_dim)`.
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_image_feature = image_features.hidden_states[vision_feature_layer]
else:
hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
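
# A small, self-contained illustration of the final torch.split step above: two images that
# map to 5 and 7 patches are stacked for one vision-tower pass and then split back per image
# (image_length 576 and embed_dim 16 are toy values):
import torch
image_num_patches = [5, 7]
projected = torch.randn(sum(image_num_patches), 576, 16)
per_image = torch.split(projected, image_num_patches, dim=0)
print([tuple(t.shape) for t in per_image])  # [(5, 576, 16), (7, 576, 16)]
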
def prepare_inputs_for_generation(
self,
@@ -170,68 +318,33 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
-batch_size, num_patches, num_channels, height, width = pixel_values.shape
-reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
-image_features = self.vision_tower(
-    reshaped_pixel_values,
-    output_hidden_states=True,
-    use_flash_attention=use_flash_attention,
-    flash_attention_recompute=flash_attention_recompute,
-)
-selected_image_feature = image_features.hidden_states[vision_feature_layer]
-if vision_feature_select_strategy == "default":
-    selected_image_feature = selected_image_feature[:, 1:]
-elif vision_feature_select_strategy == "full":
-    selected_image_feature = selected_image_feature
-image_features = self.multi_modal_projector(selected_image_feature)
-# split up image_features for each of the individual images
-# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-# if we assume each image has 5 image features (base image + 4 patches)
-split_sizes = [image.shape[0] for image in pixel_values]
-image_features = torch.split(image_features, split_sizes, dim=0)
+image_features = self.get_image_features(
+    pixel_values,
+    image_sizes,
+    vision_feature_layer=vision_feature_layer,
+    vision_feature_select_strategy=vision_feature_select_strategy,
+)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
-new_image_features = []
-for image_idx, image_feature in enumerate(image_features):
-    if image_feature.shape[0] > 1:
-        base_image_feature = image_feature[0]
-        image_feature = image_feature[1:]
-        if height * width != base_image_feature.shape[0]:
-            raise ValueError("The number of patches is not consistent with the image size.")
-        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-            image_sizes[image_idx].tolist(),
-            self.config.image_grid_pinpoints,
-            self.config.vision_config.image_size,
-        )
-        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-        image_feature = unpad_image(image_feature, image_sizes[image_idx])
-        image_feature = torch.cat(
-            (
-                image_feature,
-                self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
-            ),
-            dim=-1,
-        )
-        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
-    else:
-        image_feature = image_feature[0]
-        image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
-    new_image_features.append(image_feature)
-image_features = torch.stack(new_image_features, dim=0)
-inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
-self.image_offset = image_features.shape[1] - 1  # image_token has occupied 1 token position.
+image_features, feature_lens = self.pack_image_features(
+    image_features,
+    image_sizes,
+    vision_feature_select_strategy=vision_feature_select_strategy,
+    image_newline=self.image_newline,
+)
+special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+if inputs_embeds[special_image_mask].numel() != image_features.numel():
+    n_image_tokens = (input_ids == self.config.image_token_index).sum()
+    n_image_features = image_features.shape[0]
+    raise ValueError(
+        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+    )
+image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
elif past_key_values is not None:
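
# A toy, self-contained sketch of the masked_scatter merge used in the rewritten branch above
# (batch, sequence and embedding sizes plus the image token id 32000 are made up):
import torch
image_token_index = 32000
input_ids = torch.tensor([[1, 32000, 32000, 5]])
inputs_embeds = torch.zeros(1, 4, 8)
image_features = torch.ones(2, 8)                     # one projected row per image token
special_image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1, 0].item(), inputs_embeds[0, 3, 0].item())  # 1.0 0.0
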


@@ -391,15 +391,17 @@ class VlmCausalLMBatch(CausalLMBatch):
    elif chunk_type == "image":
        image = Image.open(BytesIO(chunk.image.data))
        # TODO unsure about BOS
-       if config.model_type == "mllama":
-           curr_text = image_text_replacement(config) + curr_text
-       else:
-           curr_text += image_text_replacement(config)
        curr_image = image
        curr_i = i
    else:
        raise RuntimeError(f"Invalid chunk type {chunk_type}")
+if image_text_replacement(config) not in curr_text:
+    if "<image>" in curr_text:
+        curr_text = curr_text.replace("<image>", image_text_replacement(config))
+    else:
+        curr_text = image_text_replacement(config) + curr_text
texts.append(curr_text)
if curr_image is not None:
    if config.model_type == "mllama":
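
# A minimal sketch of the placeholder handling added in the hunk above; `ensure_image_placeholder`
# is a hypothetical stand-in for the inline logic and `replacement` plays the role of
# image_text_replacement(config):
def ensure_image_placeholder(curr_text: str, replacement: str) -> str:
    if replacement in curr_text:
        return curr_text
    if "<image>" in curr_text:
        return curr_text.replace("<image>", replacement)
    return replacement + curr_text

print(ensure_image_placeholder("<image>\nWhat is shown here?", "<image>" * 3))
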
@@ -416,18 +418,17 @@ class VlmCausalLMBatch(CausalLMBatch):
dummy_inputs = []
if len(texts) > 0:
    dummy_inputs = [texts[0]] * missing_inputs
-   if config.model_type == "mllama":
-       dummy_images = [images[0]] * missing_inputs
-   else:
-       dummy_images = [images[0]] * missing_inputs
+   dummy_images = [images[0]] * missing_inputs
texts += dummy_inputs
images += dummy_images
processor_output = processor(images,
                             texts,
                             truncation=True,
                             max_length=r.truncate,
                             add_special_tokens=r.add_special_tokens,
                             return_tensors="pt",
+                            padding_side="left",
                             padding="longest")
if "input_ids" in processor_output:
    batch_tokenized_inputs.update({"input_ids" : processor_output["input_ids"]})