# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava-NeXT model."""

import functools
from typing import List, Optional, Union

import numpy as np
import torch
import torch.utils.checkpoint
from loguru import logger
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
from transformers.image_processing_utils import select_best_resolution
from transformers.models.llava_next.modeling_llava_next import (
    unpad_image,
)


@functools.lru_cache(maxsize=None)
def _warning_once(message: str) -> None:
    """Log `message` at WARNING level only the first time it is seen.

    The `loguru` logger imported by this module has no `warning_once` method,
    so the "emit once" behaviour is provided by caching on the message text.
    """
    logger.warning(message)


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`torch.Tensor` or `np.ndarray` or `tuple` or `list`):
            The size of the input image. It is forwarded to `select_best_resolution`
            (presumably in `(height, width)` order -- verify against callers).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (height, width), i.e. the selected
        resolution integer-divided by `patch_size`.

    Raises:
        ValueError: If `grid_pinpoints` is not a list.
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    # ! VERY IMPORTANT: tensors / arrays must be converted to plain Python numbers first,
    # otherwise the resolution selection can be computed incorrectly (same conversion as in
    # `image_size_to_num_patches`). Other input types are passed through unchanged.
    if isinstance(image_size, (torch.Tensor, np.ndarray)):
        image_size = image_size.tolist()

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size


# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
    """
    Calculate the number of patches after the preprocessing for images of any resolution.

    Args:
        image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        int: the number of patches (grid tiles of the selected resolution plus one base patch)

    Raises:
        TypeError: If `grid_pinpoints` is not a list, or if `image_size` is neither a
            list/tuple nor a tensor/ndarray.
    """
    if not isinstance(grid_pinpoints, list):
        raise TypeError("grid_pinpoints should be a list of tuples or lists")

    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
    if not isinstance(image_size, (list, tuple)):
        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
            raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
        image_size = image_size.tolist()

    height, width = select_best_resolution(image_size, grid_pinpoints)

    # One patch per tile start along each axis, i.e. ceil(height / patch_size) * ceil(width / patch_size).
    # `len(range(...))` gives exactly the count the former nested loops produced.
    num_patches = len(range(0, height, patch_size)) * len(range(0, width, patch_size))
    # add the base patch
    return num_patches + 1


class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_idx: Optional[torch.Tensor] = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
    ):
        """
        Run the language model on already-merged embeddings when `token_idx` is given.

        When `token_idx` is not None, `inputs_embeds` (computed from `input_ids` if missing) is
        passed straight to `self.language_model`; `pixel_values`, `image_sizes` and `labels` are
        not used on this path (image merging happens in `prepare_inputs_for_generation`).
        When `token_idx` is None the call is delegated to the parent class' `forward`.

        Returns:
            The language model outputs, or a tuple `(logits, *rest)` when `return_dict` is falsy.
        """
        if token_idx is None:
            # Without `token_idx` there is nothing specific to do here: defer to the parent
            # implementation instead of silently returning None.
            return super().forward(
                input_ids=input_ids,
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                token_idx=token_idx,
                use_flash_attention=use_flash_attention,
                flash_attention_recompute=flash_attention_recompute,
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            token_idx=token_idx,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=flash_attention_recompute,
        )

        logits = outputs[0]

        if not return_dict:
            output = (logits,) + outputs[1:]
            return output

        return outputs

    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
        """
        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

        Args:
            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensor, each contains all the visual feature of all patches.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
            image_newline (`torch.Tensor` of shape `(embed_dim)`)
                New line embedding vector.
        Returns:
            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
            feature_lens (`torch.LongTensor`)
                token length of each image in image_features
        """
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                # First entry is the base (whole image) patch, the rest are the grid tiles.
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )

                if (
                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    # The loguru logger has no `warning_once`; use the module-level cached helper.
                    _warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

                # (tiles, tokens, dim) -> (dim, grid_h * height, grid_w * width)
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    # Append one newline embedding at the end of every feature row.
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        image_features = torch.cat(new_image_features, dim=0)
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
        return image_features, feature_lens

    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: Union[int, List[int]],
        vision_feature_select_strategy: str,
    ):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
               The tensors corresponding to the input images.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_layer (`Union[int, List[int]]`):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
            and are of shape `(num_patches, image_length, embed_dim)`).
        Raises:
            ValueError: If `pixel_values` has neither 4 nor 5 dimensions.
        """
        # ! infer image_num_patches from image_sizes
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]
        if pixel_values.dim() == 5:
            # stacked if input is (batch_size, num_patches, num_channels, height, width)
            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        # If we have one vision feature layer, return the corresponding hidden states,
        # otherwise, select the hidden states of each feature layer and concatenate them
        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_features.hidden_states[vision_feature_layer]
        else:
            hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        # "default" drops the first token of every patch; any other strategy (e.g. "full")
        # keeps the selected features untouched.
        if vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]

        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = torch.split(image_features, image_num_patches, dim=0)
        return image_features

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        **kwargs,
    ):
        """
        Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
        The only differences are:
        - add new args token_idx
        - add the process of merging images into inputs_embeds

        When `token_idx` is absent from `kwargs` the parent implementation is used unchanged.
        """
        token_idx = kwargs.get("token_idx", None)
        if token_idx is None:
            return super().prepare_inputs_for_generation(
                input_ids=input_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                attention_mask=attention_mask,
                **kwargs,
            )

        use_flash_attention = kwargs.get("use_flash_attention", False)
        flash_attention_recompute = kwargs.get("flash_attention_recompute", False)

        position_ids = kwargs.get("position_ids", None)
        labels = kwargs.get("labels", None)
        if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
            vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
            vision_feature_layer = kwargs.get("vision_feature_layer", None)
            vision_feature_select_strategy = (
                vision_feature_select_strategy
                if vision_feature_select_strategy is not None
                else self.config.vision_feature_select_strategy
            )
            vision_feature_layer = (
                vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
            )

            # 1. Extract the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)
            # 2. Merge text and images
            image_features = self.get_image_features(
                pixel_values,
                image_sizes,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
            image_features, feature_lens = self.pack_image_features(
                image_features,
                image_sizes,
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_newline=self.image_newline,
            )

            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            if inputs_embeds[special_image_mask].numel() != image_features.numel():
                n_image_tokens = (input_ids == self.config.image_token_index).sum()
                n_image_features = image_features.shape[0]
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )

            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
        # generation with cache
        elif past_key_values is not None:
            seq_len = input_ids.shape[1]
            pad_len = seq_len - token_idx.item()
            input_ids = torch.index_select(input_ids, 1, token_idx - 1)
            # Retrieve the first layer to inspect the logits and mask out the hidden states
            # that are set to 0
            first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

            # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
            batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

            # Get the target length
            past_length = first_layer_past_key_value.shape[-1]

            extended_attention_mask = torch.ones(
                (attention_mask.shape[0], past_length),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )

            # Filter out only the tokens that can be un-attended, this can happen
            # if one uses Llava + Fused modules where the cache on the
            # first iteration is already big enough, or if one passes custom cache
            valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
            new_batch_index = batch_index[valid_indices]
            new_non_attended_tokens = non_attended_tokens[valid_indices]

            # Zero-out the places where we don't need to attend
            extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

            attention_mask = extended_attention_mask
            # Mask the trailing padding slots. Guard against pad_len == 0: `[:, -0:]` is the
            # whole tensor and would zero out the entire attention mask.
            if pad_len > 0:
                attention_mask[:, -pad_len:] = 0

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                # `token_idx` is always set on this code path, so the position is the
                # number of attended tokens minus one.
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "token_idx": token_idx,
                "labels": labels,
                "use_flash_attention": use_flash_attention,
                "flash_attention_recompute": flash_attention_recompute,
            }
        )

        return model_inputs