diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
index 70449f6b7..00ecdf952 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -20,6 +20,7 @@
 import torch
 import torch.utils.checkpoint
 import numpy as np
+from loguru import logger
 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
 )
@@ -92,23 +93,6 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
 
 
 class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
-    def _merge_input_ids_with_image_features(
-        self,
-        inputs_embeds: torch.Tensor,
-        image_features: torch.Tensor,
-        input_ids: torch.Tensor,
-    ):
-        """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
-        # Let's pray we have enabled enough slots !
-        try:
-            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-        except Exception as e:
-            raise RuntimeError(
-                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
-            )
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -169,6 +153,92 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
 
         return outputs
 
+    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
+    def pack_image_features(
+        self,
+        image_features,
+        image_sizes,
+        vision_feature_select_strategy,
+        image_newline=None,
+    ):
+        """
+        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+        Args:
+            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+                List of image feature tensor, each contains all the visual feature of all patches.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each images (H, W).
+            vision_feature_select_strategy (`str`)
+                The feature selection strategy used to select the vision feature from the vision backbone.
+            image_newline (`torch.Tensor` of shape `(embed_dim)`)
+                New line embedding vector.
+        Returns:
+            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+            feature_lens (`List[int]`)
+                token length of each image in image_features
+        """
+        new_image_features = []
+        feature_lens = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+                height = width = (
+                    self.config.vision_config.image_size
+                    // self.config.vision_config.patch_size
+                )
+
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+
+                if (
+                    np.prod(image_feature.shape)
+                    % (num_patch_height * num_patch_width * height * width)
+                    != 0
+                    and vision_feature_select_strategy == "default"
+                ):
+                    logger.warning_once(
+                        "Image feature shape does not line up with the provided patch size. "
+                        "You may be using the `default` vision_feature_select_strategy with a"
+                        " visual encoder that does not have CLS."
+                    )
+
+                image_feature = image_feature.view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            image_newline[:, None, None]
+                            .expand(*image_feature.shape[:-1], 1)
+                            .to(image_feature.device, image_feature.dtype),
+                        ),
+                        dim=-1,
+                    )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
+                image_feature = image_feature[0]
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (image_feature, image_newline[None].to(image_feature)), dim=0
+                    )
+            new_image_features.append(image_feature)
+            feature_lens.append(image_feature.size(0))
+        image_features = torch.cat(new_image_features, dim=0)
+        feature_lens = torch.tensor(
+            feature_lens, dtype=torch.long, device=image_features.device
+        )
+        return image_features, feature_lens
+
     # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
     def get_image_features(
         self,
@@ -303,61 +373,33 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 )
                 # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-                height = width = (
-                    self.config.vision_config.image_size
-                    // self.config.vision_config.patch_size
+                image_features, feature_lens = self.pack_image_features(
+                    image_features,
+                    image_sizes,
+                    vision_feature_select_strategy=vision_feature_select_strategy,
+                    image_newline=self.image_newline,
                 )
-                new_image_features = []
-                for image_idx, image_feature in enumerate(image_features):
-                    if image_feature.shape[0] > 1:
-                        base_image_feature = image_feature[0]
-                        image_feature = image_feature[1:]
-
-                        if height * width != base_image_feature.shape[0]:
-                            raise ValueError(
-                                "The number of patches is not consistent with the image size."
-                            )
-
-                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-                            image_sizes[image_idx].tolist(),
-                            self.config.image_grid_pinpoints,
-                            self.config.vision_config.image_size,
-                        )
-
-                        image_feature = image_feature.view(
-                            num_patch_height, num_patch_width, height, width, -1
-                        )
-                        image_feature = image_feature.permute(
-                            4, 0, 2, 1, 3
-                        ).contiguous()
-                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                        image_feature = unpad_image(
-                            image_feature, image_sizes[image_idx]
-                        )
-                        image_feature = torch.cat(
-                            (
-                                image_feature,
-                                self.image_newline[:, None, None].expand(
-                                    *image_feature.shape[:-1], 1
-                                ),
-                            ),
-                            dim=-1,
-                        )
-                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        image_feature = torch.cat(
-                            (base_image_feature, image_feature), dim=0
-                        )
-                    else:
-                        image_feature = image_feature[0]
-                        image_feature = torch.cat(
-                            (image_feature, self.image_newline[None]), dim=0
-                        )
-                    new_image_features.append(image_feature)
-                image_features = torch.cat(new_image_features, dim=0)
-                inputs_embeds = self._merge_input_ids_with_image_features(
-                    inputs_embeds, image_features, input_ids
+                special_image_mask = (
+                    input_ids == self.config.image_token_index
+                ).unsqueeze(-1)
+                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
+                    inputs_embeds.device
                 )
+                if inputs_embeds[special_image_mask].numel() != image_features.numel():
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum()
+                    n_image_features = image_features.shape[0]
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
+
+                image_features = image_features.to(
+                    inputs_embeds.device, inputs_embeds.dtype
+                )
+                inputs_embeds = inputs_embeds.masked_scatter(
+                    special_image_mask, image_features
+                )
+
             # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
             # generation with cache
             elif past_key_values is not None:
diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py
index 66e00171b..543b07e8e 100644
--- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py
@@ -428,6 +428,9 @@ class VlmCausalLMBatch(CausalLMBatch):
             else:
                 images.append(curr_image)
 
+        if is_warmup is True:
+            images += [images[0]] * (len(texts) - len(images))
+
         missing_inputs = 0
         dummy_images = None
         if is_warmup is False:
@@ -1464,7 +1467,6 @@ class VlmCausalLM(Model):
         batch = self.batch_from_pb(request.batch, is_warmup=True)
         max_input_tokens = request.max_input_tokens
         max_prefill_batch_size = batch.input_ids.shape[0]
-
         try:
             # max prefill batch size warmup
             _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
@@ -1548,7 +1550,7 @@ class VlmCausalLM(Model):
                     request,
                     PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                     max_prefill_batch_size,
-                    is_warmup=False,
+                    is_warmup=True,
                 )
                 _, prefill_batch, _ = self.generate_token(
                     [batch], is_warmup=True
@@ -1568,7 +1570,7 @@ class VlmCausalLM(Model):
                     request,
                     PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                     2,
-                    is_warmup=False,
+                    is_warmup=True,
                 )
                 _, prefill_batch, _ = self.generate_token(
                     [batch], is_warmup=True
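
Note on the embedding merge in llava_next.py: the removed _merge_input_ids_with_image_features wrote image features into inputs_embeds through a 2D boolean mask inside a try/except, while the new code expands the mask over the embedding dimension, validates the token/feature counts explicitly, and uses masked_scatter, matching upstream transformers. The sketch below is illustrative only, not the TGI code: shapes, values, and the image token id are made up.

import torch

batch, seq_len, dim = 2, 6, 4
image_token_index = 32000  # made-up id for illustration

input_ids = torch.tensor(
    [
        [1, 32000, 32000, 5, 6, 7],
        [1, 2, 32000, 32000, 32000, 7],
    ]
)
inputs_embeds = torch.zeros(batch, seq_len, dim)
# One feature row per image token in the batch (2 + 3 = 5 here).
image_features = torch.arange(5 * dim, dtype=torch.float32).view(5, dim)

# Removed approach: advanced indexing with a 2D mask (in-place write).
mask = input_ids == image_token_index
old = inputs_embeds.clone()
old[mask] = image_features.view(-1, image_features.shape[-1])

# New approach: expand the mask over the embedding dim, check counts,
# then masked_scatter (out-of-place, fills masked slots in row-major order).
special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)
assert inputs_embeds[special_image_mask].numel() == image_features.numel()
new = inputs_embeds.masked_scatter(special_image_mask, image_features)

assert torch.equal(old, new)

Both formulations fill the same positions; the difference is that the new path fails with a precise ValueError on a count mismatch before any write happens, instead of the old catch-all RuntimeError.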
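Note on the warmup changes in vlm_causal_lm.py: a warmup batch can carry more prompts than images, so the batch builder now pads the image list by repeating the first image, and the decode-warmup paths pass is_warmup=True so that padding actually triggers. A toy sketch of the list arithmetic, with placeholder values standing in for the real batch contents:

# Illustrative only: placeholder strings, not the real text/image objects.
texts = ["prompt 0", "prompt 1", "prompt 2"]
images = ["image 0"]  # warmup batches may carry fewer images than texts

is_warmup = True
if is_warmup is True:
    # Repeat the first image until the two lists line up, as in the diff.
    images += [images[0]] * (len(texts) - len(images))

assert len(images) == len(texts)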