Mirror of https://github.com/huggingface/text-generation-inference.git
Fix Llava next crash issue (#285)

Signed-off-by: yuanwu <yuan.wu@intel.com>

parent: 20ea73c6d4
commit: cd57fea11b

@@ -14,14 +14,13 @@
# limitations under the License.
""" PyTorch Llava-NeXT model."""

from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn
import numpy as np

from loguru import logger
from transformers.activations import ACT2FN
from transformers.models.llava_next.modeling_llava_next import (
    unpad_image,
)
@@ -50,26 +49,45 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size


# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
    """
    Calculate the number of patches after the preprocessing for images of any resolution.

    Args:
        image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        int: the number of patches
    """
    if not isinstance(grid_pinpoints, list):
        raise TypeError("grid_pinpoints should be a list of tuples or lists")

    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted to a tuple,
    # otherwise the calculation below is wrong
    if not isinstance(image_size, (list, tuple)):
        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
            raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
        image_size = image_size.tolist()

    best_resolution = select_best_resolution(image_size, grid_pinpoints)
    height, width = best_resolution
    num_patches = 0
    # consider changing to ceil(height / patch_size) * ceil(width / patch_size) + 1
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            num_patches += 1
    # add the base patch
    num_patches += 1
    return num_patches

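As a quick sanity check of this helper, here is an illustrative sketch (not part of the diff; the pinpoint list and sizes below are made up but typical for LLaVA-NeXT): an image that exactly matches the (672, 672) pinpoint yields a 2x2 grid of 336-pixel cells plus one base patch.

# Illustrative only: mirrors the counting loop above on concrete numbers.
image_size = (672, 672)                     # matches one grid pinpoint exactly
grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
patch_size = 336                            # the model passes vision_config.image_size here

best_height, best_width = 672, 672          # what select_best_resolution returns for this input
num_patches = sum(
    1 for _ in range(0, best_height, patch_size) for _ in range(0, best_width, patch_size)
) + 1                                       # +1 for the base patch
assert num_patches == 5
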
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

    def _merge_input_ids_with_image_features(
        self,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
        input_ids: torch.Tensor,
    ):
        """In-place merges vision_embeddings into inputs_embeds."""
        mask = input_ids == self.config.image_token_index
        # Let's pray we have enabled enough slots!
        try:
            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        except Exception as e:
            raise RuntimeError(
                f"Cannot fill images right now. If the error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If the error happens at regular runtime, please file an issue: {e}"
            )
        return inputs_embeds

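To see what the mask-indexed assignment above does, here is a toy-tensor sketch (illustrative only; the token id and shapes are made up):

# Toy example of filling image-token positions by boolean indexing.
import torch

image_token_index = 32000                         # hypothetical image token id
input_ids = torch.tensor([[1, 32000, 32000, 5]])  # (batch, seq_len)
inputs_embeds = torch.zeros(1, 4, 8)              # (batch, seq_len, hidden)
image_features = torch.ones(2, 8)                 # one row per image token slot

mask = input_ids == image_token_index             # (1, 4) boolean mask
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
# sequence positions 1 and 2 now carry the projected vision features
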
    def forward(
        self,
        input_ids: torch.LongTensor = None,
@@ -121,6 +139,136 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                return output

            return outputs

    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
        """
        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

        Args:
            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensors, each containing all the visual features of all patches.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each image (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
            image_newline (`torch.Tensor` of shape `(embed_dim)`)
                New line embedding vector.
        Returns:
            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
            feature_lens (`List[int]`)
                token length of each image in image_features
        """
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )

                if (
                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        image_features = torch.cat(new_image_features, dim=0)
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
        return image_features, feature_lens

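The permute/flatten sequence above rearranges the per-crop features back into a single 2-D spatial map before unpadding. A dummy-shape sketch (illustrative only, with made-up dimensions) of that rearrangement:

# Shape sketch only: follows the view/permute/flatten steps above on random data.
import torch

embed_dim = 32
height = width = 24                      # e.g. 336 // 14 feature positions per side
num_patch_height = num_patch_width = 2   # 2x2 grid of high-resolution crops

crops = torch.randn(num_patch_height * num_patch_width, height * width, embed_dim)
feat = crops.view(num_patch_height, num_patch_width, height, width, embed_dim)
feat = feat.permute(4, 0, 2, 1, 3).contiguous()  # (embed_dim, grid_h, h, grid_w, w)
feat = feat.flatten(1, 2).flatten(2, 3)          # (embed_dim, grid_h * h, grid_w * w)
assert feat.shape == (embed_dim, 48, 48)
# unpad_image would then crop this map back to the original aspect ratio
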
    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: Union[int, List[int]],
        vision_feature_select_strategy: str,
    ):
        """
        Obtains image last hidden states from the vision tower and applies multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
                The tensors corresponding to the input images.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each image (H, W).
            vision_feature_layer (`Union[int, List[int]]`):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision features of the corresponding indices will be concatenated to form the
                vision features.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (`List[torch.Tensor]`): List of image feature tensors, each containing all the visual features of all patches
            and of shape `(num_patches, image_length, embed_dim)`.
        """
        # ! infer image_num_patches from image_sizes
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]
        if pixel_values.dim() == 5:
            # stacked if input is (batch_size, num_patches, num_channels, height, width)
            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            # otherwise it has to be stacked from a list of (num_patches, num_channels, height, width)
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        # If we have one vision feature layer, return the corresponding hidden states,
        # otherwise, select the hidden states of each feature layer and concatenate them
        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_features.hidden_states[vision_feature_layer]
        else:
            hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        if vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature

        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = torch.split(image_features, image_num_patches, dim=0)
        return image_features

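A dummy-data sketch (illustrative shapes, not the real model path) of the padded 5-D input being trimmed to its true patch counts before the vision tower, and of the per-image split performed afterwards:

# Toy illustration of the pixel_values trimming and the final torch.split above.
import torch

image_num_patches = [5, 3]                       # true patch counts per image
pixel_values = torch.randn(2, 5, 3, 336, 336)    # padded to the max patch count

trimmed = [pix[:n] for pix, n in zip(pixel_values, image_num_patches)]
flat = torch.cat(trimmed, dim=0)                 # (5 + 3, 3, 336, 336)
assert flat.shape[0] == sum(image_num_patches)

projected = torch.randn(flat.shape[0], 576, 32)  # stand-in for projector output
per_image = torch.split(projected, image_num_patches, dim=0)
assert [f.shape[0] for f in per_image] == image_num_patches
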
    def prepare_inputs_for_generation(
        self,
@@ -170,68 +318,33 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
            # 1. Extract the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)
            # 2. Merge text and images
            batch_size, num_patches, num_channels, height, width = pixel_values.shape
            reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
            image_features = self.vision_tower(
                reshaped_pixel_values,
                output_hidden_states=True,
                use_flash_attention=use_flash_attention,
                flash_attention_recompute=flash_attention_recompute,
            image_features = self.get_image_features(
                pixel_values,
                image_sizes,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

            selected_image_feature = image_features.hidden_states[vision_feature_layer]

            if vision_feature_select_strategy == "default":
                selected_image_feature = selected_image_feature[:, 1:]
            elif vision_feature_select_strategy == "full":
                selected_image_feature = selected_image_feature

            image_features = self.multi_modal_projector(selected_image_feature)

            # split up image_features for each of the individual images
            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
            # if we assume each image has 5 image features (base image + 4 patches)
            split_sizes = [image.shape[0] for image in pixel_values]
            image_features = torch.split(image_features, split_sizes, dim=0)

            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
            height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

            new_image_features = []
            for image_idx, image_feature in enumerate(image_features):
                if image_feature.shape[0] > 1:
                    base_image_feature = image_feature[0]
                    image_feature = image_feature[1:]

                    if height * width != base_image_feature.shape[0]:
                        raise ValueError("The number of patches is not consistent with the image size.")

                    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                        image_sizes[image_idx].tolist(),
                        self.config.image_grid_pinpoints,
                        self.config.vision_config.image_size,
            image_features, feature_lens = self.pack_image_features(
                image_features,
                image_sizes,
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_newline=self.image_newline,
            )

                    image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
                    image_feature = torch.cat(
                        (
                            image_feature,
                            self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
                        ),
                        dim=-1,
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            if inputs_embeds[special_image_mask].numel() != image_features.numel():
                n_image_tokens = (input_ids == self.config.image_token_index).sum()
                n_image_features = image_features.shape[0]
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                else:
                    image_feature = image_feature[0]
                    image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
                new_image_features.append(image_feature)
            image_features = torch.stack(new_image_features, dim=0)
            inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
            self.image_offset = image_features.shape[1] - 1  # image_token has occupied 1 token position.

            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
        # generation with cache
        elif past_key_values is not None:

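For intuition, a toy-tensor sketch (illustrative only; ids and shapes are made up) of the masked_scatter merge used in the new path: the mask is expanded to the hidden dimension and the flattened image features fill the masked positions in order.

# Toy illustration of merging image features via masked_scatter.
import torch

image_token_index = 32000                         # hypothetical image token id
input_ids = torch.tensor([[1, 32000, 32000, 5]])
inputs_embeds = torch.zeros(1, 4, 8)
image_features = torch.arange(2 * 8, dtype=torch.float).view(2, 8)

special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# positions 1 and 2 of the sequence now hold the two image feature rows
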
@@ -391,15 +391,17 @@ class VlmCausalLMBatch(CausalLMBatch):
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    # TODO unsure about BOS
                    if config.model_type == "mllama":
                        curr_text = image_text_replacement(config) + curr_text
                    else:
                        curr_text += image_text_replacement(config)
                    curr_image = image
                    curr_i = i
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

            if image_text_replacement(config) not in curr_text:
                if "<image>" in curr_text:
                    curr_text = curr_text.replace("<image>", image_text_replacement(config))
                else:
                    curr_text = image_text_replacement(config) + curr_text

            texts.append(curr_text)
            if curr_image is not None:
                if config.model_type == "mllama":
@@ -416,18 +418,17 @@ class VlmCausalLMBatch(CausalLMBatch):
            dummy_inputs = []
            if len(texts) > 0:
                dummy_inputs = [texts[0]] * missing_inputs
            if config.model_type == "mllama":
                dummy_images = [images[0]] * missing_inputs
            else:
                dummy_images = [images[0]] * missing_inputs
            texts += dummy_inputs
            images += dummy_images

        processor_output = processor(images,
                                     texts,
                                     truncation=True,
                                     max_length=r.truncate,
                                     add_special_tokens=r.add_special_tokens,
                                     return_tensors="pt",
                                     padding_side="left",
                                     padding="longest")
        if "input_ids" in processor_output:
            batch_tokenized_inputs.update({"input_ids": processor_output["input_ids"]})
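Finally, the dummy-input padding above can be pictured with a small standalone sketch (the names and bucket size are made up): when the batch has fewer requests than the bucketed batch size, the first text/image pair is duplicated to fill the gap before calling the processor.

# Hypothetical illustration of padding a batch to a bucketed size before processing.
texts = ["What is in this picture?"]
images = [["img_0"]]            # stand-in for lists of PIL images
padded_batch_size = 4           # e.g. the next batch bucket

missing_inputs = padded_batch_size - len(texts)
if len(texts) > 0 and missing_inputs > 0:
    texts += [texts[0]] * missing_inputs
    images += [images[0]] * missing_inputs

assert len(texts) == len(images) == padded_batch_size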