# coding=utf-8 # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Llava-NeXT model.""" from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from transformers.activations import ACT2FN from transformers.models.llava_next.modeling_llava_next import ( unpad_image, ) from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration from transformers.image_processing_utils import select_best_resolution def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ Calculate the shape of the image patch grid after the preprocessing for images of any resolution. Args: image_size (`tuple`): The size of the input image in the format (width, height). grid_pinpoints (`List`): A list containing possible resolutions. Each item in the list should be a tuple or list of the form `(height, width)`. patch_size (`int`): The size of each image patch. Returns: tuple: The shape of the image patch grid in the format (width, height). """ if not isinstance(grid_pinpoints, list): raise ValueError("grid_pinpoints should be a list of tuples or lists") height, width = select_best_resolution(image_size, grid_pinpoints) return height // patch_size, width // patch_size class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): def _merge_input_ids_with_image_features( self, inputs_embeds: torch.Tensor, image_features: torch.Tensor, input_ids: torch.Tensor, ): """In place merges in vision_embeddings with inputs_embeds.""" mask = input_ids == self.config.image_token_index # Let's pray we have enabled enough slots ! try: inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) except Exception as e: raise RuntimeError( f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}" ) return inputs_embeds def forward( self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[int] = None, vision_feature_select_strategy: Optional[str] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, ): if token_idx is not None: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, token_idx=token_idx, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, ) logits = outputs[0] if not return_dict: output = (logits,) + outputs[1:] return output return outputs def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, image_sizes=None, attention_mask=None, **kwargs, ): """ Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 The only differences are: - add new args token_idx - add the process of merging images into inputs_embeds """ token_idx = kwargs.get("token_idx", None) if token_idx is None: return super().prepare_inputs_for_generation( input_ids=input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, pixel_values=pixel_values, image_sizes=image_sizes, attention_mask=attention_mask, **kwargs, ) else: use_flash_attention = kwargs.get("use_flash_attention", False) flash_attention_recompute = kwargs.get("flash_attention_recompute", False) position_ids = kwargs.get("position_ids", None) labels = kwargs.get("labels", None) if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1: vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None) vision_feature_layer = kwargs.get("vision_feature_layer", None) vision_feature_select_strategy = ( vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_feature_select_strategy ) vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) # 1. Extract the input embeddings inputs_embeds = self.get_input_embeddings()(input_ids) # 2. Merge text and images batch_size, num_patches, num_channels, height, width = pixel_values.shape reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width) image_features = self.vision_tower( reshaped_pixel_values, output_hidden_states=True, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, ) selected_image_feature = image_features.hidden_states[vision_feature_layer] if vision_feature_select_strategy == "default": selected_image_feature = selected_image_feature[:, 1:] elif vision_feature_select_strategy == "full": selected_image_feature = selected_image_feature image_features = self.multi_modal_projector(selected_image_feature) # split up image_features for each of the individual images # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) # if we assume each image has 5 image features (base image + 4 patches) split_sizes = [image.shape[0] for image in pixel_values] image_features = torch.split(image_features, split_sizes, dim=0) # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size new_image_features = [] for image_idx, image_feature in enumerate(image_features): if image_feature.shape[0] > 1: base_image_feature = image_feature[0] image_feature = image_feature[1:] if height * width != base_image_feature.shape[0]: raise ValueError("The number of patches is not consistent with the image size.") num_patch_height, num_patch_width = get_anyres_image_grid_shape( image_sizes[image_idx].tolist(), self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() image_feature = image_feature.flatten(1, 2).flatten(2, 3) image_feature = unpad_image(image_feature, image_sizes[image_idx]) image_feature = torch.cat( ( image_feature, self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1), ), dim=-1, ) image_feature = image_feature.flatten(1, 2).transpose(0, 1) image_feature = torch.cat((base_image_feature, image_feature), dim=0) else: image_feature = image_feature[0] image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0) new_image_features.append(image_feature) image_features = torch.stack(new_image_features, dim=0) inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids) self.image_offset = image_features.shape[1] - 1 # image_token has occupied 1 token position. # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of # generation with cache elif past_key_values is not None: seq_len = input_ids.shape[1] pad_len = seq_len - token_idx.item() input_ids = torch.index_select(input_ids, 1, token_idx - 1) # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) # Get the target length past_length = first_layer_past_key_value.shape[-1] extended_attention_mask = torch.ones( (attention_mask.shape[0], past_length), dtype=attention_mask.dtype, device=attention_mask.device, ) # Filter out only the tokens that can be un-attended, this can happen # if one uses Llava + Fused modules where the cache on the # first iteration is already big enough, or if one passes custom cache valid_indices = non_attended_tokens < extended_attention_mask.size(-1) new_batch_index = batch_index[valid_indices] new_non_attended_tokens = non_attended_tokens[valid_indices] # Zero-out the places where we don't need to attend extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 attention_mask = extended_attention_mask attention_mask[:, -pad_len:] = 0 if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: if token_idx is not None: position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 else: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, "labels": labels, "use_flash_attention": use_flash_attention, "flash_attention_recompute": flash_attention_recompute, } ) return model_inputs