# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Llava-NeXT model."""

from typing import List, Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution

from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
    load_vision_model,
)
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (height, width).
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size


def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (`torch.Tensor`):
            The image tensor, assumed to be of shape (num_channels, height, width).
        original_size (`tuple`):
            The original size of the image (height, width).

    Returns:
        `torch.Tensor`: The unpadded image tensor.
    """
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor


# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()

        self.linear_1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
        )
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
        )

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class LlavaNextForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = config.quantize
        vision_config = config.vision_config
        # Instead of running the full vision tower and selecting hidden_states[vision_feature_layer]
        # afterwards, only instantiate the layers up to the feature layer and don't pool.
        if config.vision_feature_layer < 0:
            vision_config.num_hidden_layers += config.vision_feature_layer + 1
        else:
            vision_config.num_hidden_layers = config.vision_feature_layer + 1
        self.vision_tower = load_vision_model(
            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
            config=config.vision_config,
            weights=weights,
        )

        self.multi_modal_projector = LlavaNextMultiModalProjector(
            prefix="multi_modal_projector", config=config, weights=weights
        )

        self.image_newline = weights.get_tensor("image_newline")

        self.vocab_size = config.text_config.vocab_size
        self.config = config
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        self.text_model = load_text_model(
            prefix="language_model" if not prefix else f"{prefix}.language_model",
            config=config.text_config,
            weights=weights,
        )
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def _merge_input_ids_with_image_features(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
    ):
        """In place merges in vision_embeddings with inputs_embeds."""
        mask = input_ids == self.config.image_token_index
        # Let's pray we have enabled enough slots!
        try:
            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        except Exception as e:
            raise RuntimeError(
                f"Cannot fill images right now. If the error happens at warmup, make sure you "
                f"have enough `--max-input-tokens` to handle images. If the error happens at "
                f"regular runtime, please file an issue: {e}"
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        # Unused for this model
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None and len(pixel_values) > 0:
            # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
            # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
            # 1. Extract the input embeddings
            # 2. Merge text and images
            num_images, num_patches, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.view(
                num_images * num_patches, channels, height, width
            )
            image_features = self.vision_tower(pixel_values)

            # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
            # Already done within the clip model
            selected_image_feature = image_features.last_hidden_state

            if self.config.vision_feature_select_strategy == "default":
                selected_image_feature = selected_image_feature[:, 1:]
            elif self.config.vision_feature_select_strategy == "full":
                selected_image_feature = selected_image_feature
            else:
                raise RuntimeError(
                    f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
                )

            image_features = self.multi_modal_projector(selected_image_feature)

            # Split up image_features for each of the individual images:
            # we get a list of tensors, each of shape (num_patches, num_image_tokens, hidden_size),
            # where num_patches counts the base image plus its high-resolution patches
            # (e.g. 5 for a base image with 4 patches).
            split_sizes = [num_patches] * num_images
            image_features = torch.split(image_features, split_sizes, dim=0)

            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
            height = width = (
                self.config.vision_config.image_size
                // self.config.vision_config.patch_size
            )
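            # Shape walk-through for the loop below (illustrative numbers, assuming a
            # 336px CLIP tower with 14px patches and the "default" feature-select
            # strategy, i.e. height == width == 24 and 576 tokens per view): each
            # `image_feature` is (num_patches, 576, hidden), where the first entry is
            # the downscaled base view. The base view is kept as-is; the remaining
            # patches are rearranged into a (hidden, grid_h * 24, grid_w * 24) map,
            # unpadded back to the original aspect ratio with `unpad_image`, extended
            # with the learned `image_newline` column, flattened to (num_tokens, hidden)
            # and concatenated after the base view.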
            new_image_features = []
            for image_idx, image_feature in enumerate(image_features):
                if image_feature.shape[0] > 1:
                    base_image_feature = image_feature[0]
                    image_feature = image_feature[1:]

                    if height * width != base_image_feature.shape[0]:
                        raise ValueError(
                            "The number of patches is not consistent with the image size."
                        )

                    # Dimensions are intentionally swapped to be bug-compatible with
                    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                        image_sizes[image_idx],
                        self.config.image_grid_pinpoints,
                        self.config.vision_config.image_size,
                    )
                    image_feature = image_feature.view(
                        num_patch_height, num_patch_width, height, width, -1
                    )
                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
                    image_feature = torch.cat(
                        (
                            image_feature,
                            self.image_newline[:, None, None].expand(
                                *image_feature.shape[:-1], 1
                            ),
                        ),
                        dim=-1,
                    )
                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    image_feature = torch.cat(
                        (base_image_feature, image_feature), dim=0
                    )
                else:
                    image_feature = image_feature[0]
                    image_feature = torch.cat(
                        (image_feature, self.image_newline[None]), dim=0
                    )
                new_image_features.append(image_feature)
            image_features = torch.stack(new_image_features, dim=0)
            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, image_features
            )

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            true_max_s=max_s,
            prefill_cache_indices=None,
            adapter_data=adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)
        return logits, speculative_logits