From cd57fea11bb699369277cbe404c9b30b1d4f137a Mon Sep 17 00:00:00 2001
From: Yuan Wu
Date: Thu, 6 Mar 2025 17:12:21 +0800
Subject: [PATCH] Fix Llava next crash issue (#285)

Signed-off-by: yuanwu
---
 .../models/custom_modeling/llava_next.py | 265 +++++++++++++-----
 .../models/vlm_causal_lm.py              |  17 +-
 2 files changed, 198 insertions(+), 84 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index c58644b2..fa2d59f8 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -14,14 +14,13 @@
 # limitations under the License.
 """ PyTorch Llava-NeXT model."""
 
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
 
 import torch
 import torch.utils.checkpoint
-from torch import nn
+import numpy as np
 
 from loguru import logger
-from transformers.activations import ACT2FN
 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
 )
@@ -50,25 +49,44 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     height, width = select_best_resolution(image_size, grid_pinpoints)
     return height // patch_size, width // patch_size
 
 
+# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
+def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
+    """
+    Calculate the number of patches after the preprocessing for images of any resolution.
+
+    Args:
+        image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
+            The size of the input image in the format (height, width).
+        grid_pinpoints (`List`):
+            A list containing possible resolutions. Each item in the list should be a tuple or list
+            of the form `(height, width)`.
+        patch_size (`int`):
+            The size of each image patch.
+
+    Returns:
+        int: the number of patches
+    """
+    if not isinstance(grid_pinpoints, list):
+        raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted to a tuple, otherwise the calculation is wrong
+    if not isinstance(image_size, (list, tuple)):
+        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+            raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
+        image_size = image_size.tolist()
+
+    best_resolution = select_best_resolution(image_size, grid_pinpoints)
+    height, width = best_resolution
+    num_patches = 0
+    # consider changing to ceil(height/patch_size) * ceil(width/patch_size) + 1
+    for i in range(0, height, patch_size):
+        for j in range(0, width, patch_size):
+            num_patches += 1
+    # add the base patch
+    num_patches += 1
+    return num_patches
+
+
 class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
-
-    def _merge_input_ids_with_image_features(
-        self,
-        inputs_embeds: torch.Tensor,
-        image_features: torch.Tensor,
-        input_ids: torch.Tensor,
-    ):
-        """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
-        # Let's pray we have enabled enough slots !
-        try:
-            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-        except Exception as e:
-            raise RuntimeError(
-                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
-            )
-        return inputs_embeds
 
     def forward(
         self,
@@ -121,6 +139,136 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 return output
 
         return outputs
 
+    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
+    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+        """
+        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+        Args:
+            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+                List of image feature tensors, each containing all the visual features of all patches.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each image (H, W).
+            vision_feature_select_strategy (`str`)
+                The feature selection strategy used to select the vision feature from the vision backbone.
+            image_newline (`torch.Tensor` of shape `(embed_dim)`)
+                New line embedding vector.
+        Returns:
+            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+            feature_lens (`List[int]`)
+                Token length of each image in image_features.
+        """
+        new_image_features = []
+        feature_lens = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+
+                if (
+                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
+                    and vision_feature_select_strategy == "default"
+                ):
+                    logger.warning_once(
+                        "Image feature shape does not line up with the provided patch size. "
+                        "You may be using the `default` vision_feature_select_strategy with a"
+                        " visual encoder that does not have CLS."
+                    )
+
+                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            image_newline[:, None, None]
+                            .expand(*image_feature.shape[:-1], 1)
+                            .to(image_feature.device, image_feature.dtype),
+                        ),
+                        dim=-1,
+                    )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
+                image_feature = image_feature[0]
+                if image_newline is not None:
+                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+            new_image_features.append(image_feature)
+            feature_lens.append(image_feature.size(0))
+        image_features = torch.cat(new_image_features, dim=0)
+        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
+        return image_features, feature_lens
+
+    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        image_sizes: torch.Tensor,
+        vision_feature_layer: Union[int, List[int]],
+        vision_feature_select_strategy: str,
+    ):
+        """
+        Obtains the image's last hidden states from the vision tower and applies the multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
+                The tensors corresponding to the input images.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each image (H, W).
+            vision_feature_layer (`Union[int, List[int]]`):
+                The index of the layer to select the vision feature. If multiple indices are provided,
+                the vision feature of the corresponding indices will be concatenated to form the
+                vision features.
+            vision_feature_select_strategy (`str`):
+                The feature selection strategy used to select the vision feature from the vision backbone.
+                Can be one of `"default"` or `"full"`.
+        Returns:
+            image_features (`List[torch.Tensor]`): List of image feature tensors, each containing all the visual
+                features of all patches, of shape `(num_patches, image_length, embed_dim)`.
+        """
+        # ! infer image_num_patches from image_sizes
+        image_num_patches = [
+            image_size_to_num_patches(
+                image_size=imsize,
+                grid_pinpoints=self.config.image_grid_pinpoints,
+                patch_size=self.config.vision_config.image_size,
+            )
+            for imsize in image_sizes
+        ]
+        if pixel_values.dim() == 5:
+            # stacked if input is (batch_size, num_patches, num_channels, height, width)
+            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
+            pixel_values = torch.cat(_pixel_values_list, dim=0)
+        elif pixel_values.dim() != 4:
+            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
+            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
+
+        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
+        # If we have one vision feature layer, return the corresponding hidden states,
+        # otherwise, select the hidden states of each feature layer and concatenate them
+        if isinstance(vision_feature_layer, int):
+            selected_image_feature = image_features.hidden_states[vision_feature_layer]
+        else:
+            hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+            selected_image_feature = torch.cat(hs_pool, dim=-1)
+
+        if vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+
+        image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = torch.split(image_features, image_num_patches, dim=0)
+        return image_features
 
     def prepare_inputs_for_generation(
         self,
@@ -170,68 +318,33 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             # 1. Extract the input embeddings
             inputs_embeds = self.get_input_embeddings()(input_ids)
             # 2. Merge text and images
-            batch_size, num_patches, num_channels, height, width = pixel_values.shape
-            reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
-            image_features = self.vision_tower(
-                reshaped_pixel_values,
-                output_hidden_states=True,
-                use_flash_attention=use_flash_attention,
-                flash_attention_recompute=flash_attention_recompute,
+            image_features = self.get_image_features(
+                pixel_values,
+                image_sizes,
+                vision_feature_layer=vision_feature_layer,
+                vision_feature_select_strategy=vision_feature_select_strategy,
             )
-            selected_image_feature = image_features.hidden_states[vision_feature_layer]
-
-            if vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-
-            image_features = self.multi_modal_projector(selected_image_feature)
-
-            # split up image_features for each of the individual images
-            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-            # if we assume each image has 5 image features (base image + 4 patches)
-            split_sizes = [image.shape[0] for image in pixel_values]
-            image_features = torch.split(image_features, split_sizes, dim=0)
 
-            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-            height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+            image_features, feature_lens = self.pack_image_features(
+                image_features,
+                image_sizes,
+                vision_feature_select_strategy=vision_feature_select_strategy,
+                image_newline=self.image_newline,
+            )
 
-            new_image_features = []
-            for image_idx, image_feature in enumerate(image_features):
-                if image_feature.shape[0] > 1:
-                    base_image_feature = image_feature[0]
-                    image_feature = image_feature[1:]
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if inputs_embeds[special_image_mask].numel() != image_features.numel():
+                n_image_tokens = (input_ids == self.config.image_token_index).sum()
+                n_image_features = image_features.shape[0]
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
 
-                    if height * width != base_image_feature.shape[0]:
-                        raise ValueError("The number of patches is not consistent with the image size.")
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
-                    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-                        image_sizes[image_idx].tolist(),
-                        self.config.image_grid_pinpoints,
-                        self.config.vision_config.image_size,
-                    )
-
-                    image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
-                    image_feature = torch.cat(
-                        (
-                            image_feature,
-                            self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
-                        ),
-                        dim=-1,
-                    )
-                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)
-                else:
-                    image_feature = image_feature[0]
-                    image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
-                new_image_features.append(image_feature)
-            image_features = torch.stack(new_image_features, dim=0)
-            inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
-            self.image_offset = image_features.shape[1] - 1  # image_token has occupied 1 token position.
         # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
         # generation with cache
         elif past_key_values is not None:
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index f17e9c24..f7e159a2 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -391,15 +391,17 @@ class VlmCausalLMBatch(CausalLMBatch):
                 elif chunk_type == "image":
                     image = Image.open(BytesIO(chunk.image.data))
                     # TODO unsure about BOS
-                    if config.model_type == "mllama":
-                        curr_text = image_text_replacement(config) + curr_text
-                    else:
-                        curr_text += image_text_replacement(config)
                     curr_image = image
                     curr_i = i
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
+            if image_text_replacement(config) not in curr_text:
+                if "<image>" in curr_text:
+                    curr_text = curr_text.replace("<image>", image_text_replacement(config))
+                else:
+                    curr_text = image_text_replacement(config) + curr_text
+
             texts.append(curr_text)
             if curr_image is not None:
                 if config.model_type == "mllama":
@@ -416,18 +418,17 @@ class VlmCausalLMBatch(CausalLMBatch):
         dummy_inputs = []
         if len(texts) > 0:
             dummy_inputs = [texts[0]] * missing_inputs
-            if config.model_type == "mllama":
-                dummy_images = [images[0]] * missing_inputs
-            else:
-                dummy_images = [images[0]] * missing_inputs
+            dummy_images = [images[0]] * missing_inputs
 
         texts += dummy_inputs
         images += dummy_images
+
         processor_output = processor(images,
                                      texts,
                                      truncation=True,
                                      max_length=r.truncate,
                                      add_special_tokens=r.add_special_tokens,
                                      return_tensors="pt",
+                                     padding_side="left",
                                      padding="longest")
         if "input_ids" in processor_output:
             batch_tokenized_inputs.update({"input_ids" : processor_output["input_ids"]})
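
Note (not part of the patch): a minimal sketch of the invariant the new merge path relies on. image_size_to_num_patches decides how many vision embeddings each image contributes, and masked_scatter can only fill the placeholder positions when the number of image tokens in the prompt matches that count exactly, which is the check prepare_inputs_for_generation now performs before scattering. The image-token id (32000), the pinpoint list, and the simplified best-resolution pick below are made-up stand-ins, not values taken from the patch.

import math

import torch


def num_patches_for(image_size, grid_pinpoints, patch_size):
    # Rough stand-in for select_best_resolution: pick the pinpoint whose area is
    # closest to the input image's area, then count patch tiles plus the base patch
    # (mirrors what image_size_to_num_patches computes).
    height, width = min(
        grid_pinpoints,
        key=lambda hw: abs(hw[0] * hw[1] - image_size[0] * image_size[1]),
    )
    return math.ceil(height / patch_size) * math.ceil(width / patch_size) + 1


embed_dim = 8
image_token_id = 32000  # placeholder id, chosen for the example only
grid_pinpoints = [(336, 672), (672, 336), (672, 672)]
n_patches = num_patches_for((480, 640), grid_pinpoints, patch_size=336)

# The prompt carries exactly n_patches image-token ids, so the number of
# placeholder embedding slots equals the number of projected vision vectors.
image_features = torch.randn(n_patches, embed_dim)
input_ids = torch.tensor([[1, 2] + [image_token_id] * n_patches + [3]])
inputs_embeds = torch.zeros(1, input_ids.shape[1], embed_dim)

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
assert inputs_embeds[special_image_mask].numel() == image_features.numel()  # same check as the patch
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
assert torch.equal(inputs_embeds[0, 2 : 2 + n_patches], image_features)

If the two counts drift apart, for example because the prompt side inserted the wrong number of placeholders, the old in-place assignment crashed with an opaque indexing error; the new path raises the explicit "Image features and image tokens do not match" ValueError instead.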
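
Note (not part of the patch): the prompt-side change in vlm_causal_lm.py reduces to the normalisation sketched below: keep the model's own image token if the prompt already contains it, otherwise substitute it for a user-supplied placeholder or prepend it. image_token stands in for whatever image_text_replacement(config) returns for the model; the "<image>" placeholder string and the "<|img|>" token are assumptions made for this example.

def normalize_prompt(curr_text: str, image_token: str) -> str:
    # Mirrors the block added to batch_tokenized_inputs.
    if image_token in curr_text:
        return curr_text  # prompt already carries the model's image token
    if "<image>" in curr_text:
        return curr_text.replace("<image>", image_token)  # swap the generic placeholder
    return image_token + curr_text  # otherwise prepend the token


assert normalize_prompt("<image> What is in this picture?", "<|img|>") == "<|img|> What is in this picture?"
assert normalize_prompt("Describe the scene.", "<|img|>") == "<|img|>Describe the scene."
assert normalize_prompt("<|img|> Caption this.", "<|img|>") == "<|img|> Caption this."

The other half of that hunk, passing padding_side="left" to the processor, pads batched prompts on the left so that each sequence ends with real prompt tokens and generation continues from them.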