diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index fa2d59f8..afc64c7d 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 """ PyTorch Llava-NeXT model."""
 
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-import numpy as np
+from torch import nn
 from loguru import logger
 
+from transformers.activations import ACT2FN
 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
 )
@@ -49,44 +50,25 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     height, width = select_best_resolution(image_size, grid_pinpoints)
     return height // patch_size, width // patch_size
 
-# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
-def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
-    """
-    Calculate the number of patches after the preprocessing for images of any resolution.
-
-    Args:
-        image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
-            The size of the input image in the format (height, width). ?
-        grid_pinpoints (`List`):
-            A list containing possible resolutions. Each item in the list should be a tuple or list
-            of the form `(height, width)`.
-        patch_size (`int`):
-            The size of each image patch.
-
-    Returns:
-        int: the number of patches
-    """
-    if not isinstance(grid_pinpoints, list):
-        raise TypeError("grid_pinpoints should be a list of tuples or lists")
-
-    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
-    if not isinstance(image_size, (list, tuple)):
-        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
-            raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
-        image_size = image_size.tolist()
-
-    best_resolution = select_best_resolution(image_size, grid_pinpoints)
-    height, width = best_resolution
-    num_patches = 0
-    # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
-    for i in range(0, height, patch_size):
-        for j in range(0, width, patch_size):
-            num_patches += 1
-    # add the base patch
-    num_patches += 1
-    return num_patches
 
 class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
+
+    def _merge_input_ids_with_image_features(
+        self,
+        inputs_embeds: torch.Tensor,
+        image_features: torch.Tensor,
+        input_ids: torch.Tensor,
+    ):
+        """In-place merge of image features into inputs_embeds at image-token positions."""
+        mask = input_ids == self.config.image_token_index
+        # Let's pray we have enabled enough slots!
+        try:
+            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        except Exception as e:
+            raise RuntimeError(
+                f"Cannot fill images right now. If this error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If it happens at regular runtime, please file an issue: {e}"
+            )
+        return inputs_embeds
 
     def forward(
         self,
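A note on the new `_merge_input_ids_with_image_features` helper above: it relies on PyTorch boolean-mask assignment, which writes the i-th image embedding into the slot of the i-th image token in order, and raises if the number of masked positions and the number of image feature rows disagree (hence the warmup hint in the error message). Below is a minimal standalone sketch of the same mechanism; the toy sizes and the `image_token_index` value are hypothetical, not taken from the patch.

```python
import torch

# Hypothetical toy sizes: 8 tokens, hidden size 4, 3 image tokens.
image_token_index = 32000
input_ids = torch.tensor([[1, 32000, 32000, 32000, 5, 6, 7, 2]])
inputs_embeds = torch.zeros(1, 8, 4)
image_features = torch.ones(3, 4)  # must match the number of image tokens

mask = input_ids == image_token_index
# Boolean-mask assignment fills each image-token position with the
# corresponding image embedding; it fails if the counts differ.
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
```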
@@ -139,136 +121,6 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             return output
         return outputs
 
-    #Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
-    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
-        """
-        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
-
-        Args:
-            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
-                List of image feature tensor, each contains all the visual feature of all patches.
-            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
-                Actual image size of each images (H, W).
-            vision_feature_select_strategy (`str`)
-                The feature selection strategy used to select the vision feature from the vision backbone.
-            image_newline (`torch.Tensor` of shape `(embed_dim)`)
-                New line embedding vector.
-        Returns:
-            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
-            feature_lens (`List[int]`)
-                token length of each image in image_features
-        """
-        new_image_features = []
-        feature_lens = []
-        for image_idx, image_feature in enumerate(image_features):
-            if image_feature.shape[0] > 1:
-                base_image_feature = image_feature[0]
-                image_feature = image_feature[1:]
-                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
-
-                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-                    image_sizes[image_idx],
-                    self.config.image_grid_pinpoints,
-                    self.config.vision_config.image_size,
-                )
-
-                if (
-                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
-                    and vision_feature_select_strategy == "default"
-                ):
-                    logger.warning_once(
-                        "Image feature shape does not line up with the provided patch size. "
-                        "You may be using the `default` vision_feature_select_strategy with a"
-                        " visual encoder that does not have CLS."
-                    )
-
-                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                image_feature = unpad_image(image_feature, image_sizes[image_idx])
-                if image_newline is not None:
-                    image_feature = torch.cat(
-                        (
-                            image_feature,
-                            image_newline[:, None, None]
-                            .expand(*image_feature.shape[:-1], 1)
-                            .to(image_feature.device, image_feature.dtype),
-                        ),
-                        dim=-1,
-                    )
-                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
-            else:
-                image_feature = image_feature[0]
-                if image_newline is not None:
-                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
-            new_image_features.append(image_feature)
-            feature_lens.append(image_feature.size(0))
-        image_features = torch.cat(new_image_features, dim=0)
-        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
-        return image_features, feature_lens
-
-    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
-    def get_image_features(
-        self,
-        pixel_values: torch.FloatTensor,
-        image_sizes: torch.Tensor,
-        vision_feature_layer: Union[int, List[int]],
-        vision_feature_select_strategy: str,
-    ):
-        """
-        Obtains image last hidden states from the vision tower and apply multimodal projection.
-
-        Args:
-            pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
-                The tensors corresponding to the input images.
-            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
-                Actual image size of each images (H, W).
-            vision_feature_layer (`Union[int, List[int]]`):
-                The index of the layer to select the vision feature. If multiple indices are provided,
-                the vision feature of the corresponding indices will be concatenated to form the
-                vision features.
-            vision_feature_select_strategy (`str`):
-                The feature selection strategy used to select the vision feature from the vision backbone.
-                Can be one of `"default"` or `"full"`
-        Returns:
-            image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
-            and are of shape `(num_patches, image_length, embed_dim)`).
-        """
-        # ! infer image_num_patches from image_sizes
-        image_num_patches = [
-            image_size_to_num_patches(
-                image_size=imsize,
-                grid_pinpoints=self.config.image_grid_pinpoints,
-                patch_size=self.config.vision_config.image_size,
-            )
-            for imsize in image_sizes
-        ]
-        if pixel_values.dim() == 5:
-            # stacked if input is (batch_size, num_patches, num_channels, height, width)
-            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
-            pixel_values = torch.cat(_pixel_values_list, dim=0)
-        elif pixel_values.dim() != 4:
-            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
-            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
-
-        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
-        # If we have one vision feature layer, return the corresponding hidden states,
-        # otherwise, select the hidden states of each feature layer and concatenate them
-        if isinstance(vision_feature_layer, int):
-            selected_image_feature = image_features.hidden_states[vision_feature_layer]
-        else:
-            hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
-            selected_image_feature = torch.cat(hs_pool, dim=-1)
-
-        if vision_feature_select_strategy == "default":
-            selected_image_feature = selected_image_feature[:, 1:]
-        elif vision_feature_select_strategy == "full":
-            selected_image_feature = selected_image_feature
-
-        image_features = self.multi_modal_projector(selected_image_feature)
-        image_features = torch.split(image_features, image_num_patches, dim=0)
-        return image_features
 
     def prepare_inputs_for_generation(
         self,
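For reference, the `image_size_to_num_patches` helper deleted above reduces to the closed form its own inline comment suggested. A minimal sketch of the equivalence, assuming the `select_best_resolution` helper this file already uses (importable from `transformers.image_processing_utils`); `num_anyres_patches` is a hypothetical name, not part of the codebase:

```python
import math

from transformers.image_processing_utils import select_best_resolution

def num_anyres_patches(image_size, grid_pinpoints, patch_size):
    # One tile per patch_size stride over the best-fit resolution
    # (len(range(0, n, p)) == ceil(n / p)), plus one base patch.
    height, width = select_best_resolution(image_size, grid_pinpoints)
    return math.ceil(height / patch_size) * math.ceil(width / patch_size) + 1
```

Note the caller passed `patch_size=self.config.vision_config.image_size`: each anyres tile is a full vision-encoder input, so the "patch" size here is the encoder's image size.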
@@ -318,33 +170,67 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 # 1. Extract the input embeddings
                 inputs_embeds = self.get_input_embeddings()(input_ids)
                 # 2. Merge text and images
-                image_features = self.get_image_features(
-                    pixel_values,
-                    image_sizes,
-                    vision_feature_layer=vision_feature_layer,
-                    vision_feature_select_strategy=vision_feature_select_strategy,
+                batch_size, num_patches, num_channels, height, width = pixel_values.shape
+                reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
+                image_features = self.vision_tower(
+                    reshaped_pixel_values,
+                    output_hidden_states=True,
+                    use_flash_attention=use_flash_attention,
+                    flash_attention_recompute=flash_attention_recompute,
                 )
+                selected_image_feature = image_features.hidden_states[vision_feature_layer]
+
+                if vision_feature_select_strategy == "default":
+                    selected_image_feature = selected_image_feature[:, 1:]
+                elif vision_feature_select_strategy == "full":
+                    selected_image_feature = selected_image_feature
+
+                image_features = self.multi_modal_projector(selected_image_feature)
+
+                # split up image_features for each of the individual images
+                # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+                # if we assume each image has 5 image features (base image + 4 patches)
+                split_sizes = [image.shape[0] for image in pixel_values]
+                image_features = torch.split(image_features, split_sizes, dim=0)
+
                 # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-                image_features, feature_lens = self.pack_image_features(
-                    image_features,
-                    image_sizes,
-                    vision_feature_select_strategy=vision_feature_select_strategy,
-                    image_newline=self.image_newline,
-                )
+                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
 
-                special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-                if inputs_embeds[special_image_mask].numel() != image_features.numel():
-                    n_image_tokens = (input_ids == self.config.image_token_index).sum()
-                    n_image_features = image_features.shape[0]
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
+                new_image_features = []
+                for image_idx, image_feature in enumerate(image_features):
+                    if image_feature.shape[0] > 1:
+                        base_image_feature = image_feature[0]
+                        image_feature = image_feature[1:]
 
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+                        if height * width != base_image_feature.shape[0]:
+                            raise ValueError("The number of patches is not consistent with the image size.")
+                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                            image_sizes[image_idx].tolist(),
+                            self.config.image_grid_pinpoints,
+                            self.config.vision_config.image_size,
+                        )
+
+                        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                        image_feature = torch.cat(
+                            (
+                                image_feature,
+                                self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
+                            ),
+                            dim=-1,
+                        )
+                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+                    else:
+                        image_feature = image_feature[0]
+                        image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
+                    new_image_features.append(image_feature)
+                image_features = torch.cat(new_image_features, dim=0)
+                inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
 
             # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
             # generation with cache
             elif past_key_values is not None:
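On the per-image split performed right after the projector in the last hunk: iterating over a 5-D `pixel_values` yields one `(num_patches, C, H, W)` tensor per image, so `split_sizes` collects the per-image patch counts and `torch.split` cuts the flat `(total_patches, seq, hidden)` projection back into per-image chunks. A toy sketch with hypothetical sizes:

```python
import torch

# Hypothetical: 2 images, each preprocessed into 5 anyres patches,
# each patch encoded to 9 tokens of dim 8.
pixel_values = torch.randn(2, 5, 3, 336, 336)
image_features = torch.randn(2 * 5, 9, 8)  # stands in for the projector output

split_sizes = [image.shape[0] for image in pixel_values]  # [5, 5]
per_image = torch.split(image_features, split_sizes, dim=0)
assert len(per_image) == 2 and per_image[0].shape == (5, 9, 8)
```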
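The `view`/`permute`/`flatten` chain restored inline (the "spatial_unpad" repacking) rearranges per-tile token grids into one global 2-D feature map, which `unpad_image` then crops before the newline embedding is appended and the map is flattened back into a token sequence. A toy walk-through of just the reshape, with hypothetical sizes (2x2 tile grid, 3x3 tokens per tile, dim 8):

```python
import torch

num_patch_height = num_patch_width = 2
height = width = 3
embed_dim = 8

feat = torch.randn(num_patch_height * num_patch_width, height * width, embed_dim)
feat = feat.view(num_patch_height, num_patch_width, height, width, -1)
# -> (dim, grid_h, tok_h, grid_w, tok_w): interleave tile grid and token grid
feat = feat.permute(4, 0, 2, 1, 3).contiguous()
feat = feat.flatten(1, 2).flatten(2, 3)  # -> (dim, grid_h*tok_h, grid_w*tok_w)
assert feat.shape == (embed_dim, 6, 6)
```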