From 3c6630c6e95e207690a6bef674cd7111938b2825 Mon Sep 17 00:00:00 2001 From: yuanwu Date: Fri, 21 Mar 2025 05:47:18 +0000 Subject: [PATCH] Fix the errors of style check Signed-off-by: yuanwu --- .../models/custom_modeling/llava_next.py | 338 ++++++++++-------- .../models/vlm_causal_lm.py | 4 +- 2 files changed, 200 insertions(+), 142 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py index 52d421f7..46ad655c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py @@ -27,6 +27,7 @@ from transformers.models.llava_next.modeling_llava_next import ( from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration from transformers.image_processing_utils import select_best_resolution + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ Calculate the shape of the image patch grid after the preprocessing for images of any resolution. @@ -49,6 +50,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): height, width = select_best_resolution(image_size, grid_pinpoints) return height // patch_size, width // patch_size + # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79 def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): """ @@ -72,7 +74,9 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate if not isinstance(image_size, (list, tuple)): if not isinstance(image_size, (torch.Tensor, np.ndarray)): - raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}") + raise TypeError( + f"image_size invalid type {type(image_size)} with value {image_size}" + ) image_size = image_size.tolist() best_resolution = select_best_resolution(image_size, grid_pinpoints) @@ -86,6 +90,7 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): num_patches += 1 return num_patches + class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): def forward( @@ -110,11 +115,19 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): ): if token_idx is not None: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -140,8 +153,14 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): return outputs - #Copied from 
https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411 - def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): + # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411 + def pack_image_features( + self, + image_features, + image_sizes, + vision_feature_select_strategy, + image_newline=None, + ): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. @@ -165,7 +184,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): if image_feature.shape[0] > 1: base_image_feature = image_feature[0] image_feature = image_feature[1:] - height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + height = width = ( + self.config.vision_config.image_size + // self.config.vision_config.patch_size + ) num_patch_height, num_patch_width = get_anyres_image_grid_shape( image_sizes[image_idx], @@ -174,7 +196,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): ) if ( - np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0 + np.prod(image_feature.shape) + % (num_patch_height * num_patch_width * height * width) + != 0 and vision_feature_select_strategy == "default" ): logger.warning_once( @@ -183,7 +207,9 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): " visual encoder that does not have CLS." ) - image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() image_feature = image_feature.flatten(1, 2).flatten(2, 3) image_feature = unpad_image(image_feature, image_sizes[image_idx]) @@ -202,11 +228,15 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): else: image_feature = image_feature[0] if image_newline is not None: - image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) + image_feature = torch.cat( + (image_feature, image_newline[None].to(image_feature)), dim=0 + ) new_image_features.append(image_feature) feature_lens.append(image_feature.size(0)) image_features = torch.cat(new_image_features, dim=0) - feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) + feature_lens = torch.tensor( + feature_lens, dtype=torch.long, device=image_features.device + ) return image_features, feature_lens # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479 @@ -247,11 +277,16 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): ] if pixel_values.dim() == 5: # stacked if input is (batch_size, num_patches, num_channels, height, width) - _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)] + _pixel_values_list = [ + pix_val[:num_patch] + for pix_val, num_patch in zip(pixel_values, image_num_patches) + ] pixel_values = torch.cat(_pixel_values_list, dim=0) elif pixel_values.dim() != 4: # otherwise has to be stacked from list of (num_patches, num_channels, 
height, width) - raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") + raise ValueError( + f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions" + ) image_features = self.vision_tower(pixel_values, output_hidden_states=True) # If we have one vision feature layer, return the corresponding hidden states, @@ -259,7 +294,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): if isinstance(vision_feature_layer, int): selected_image_feature = image_features.hidden_states[vision_feature_layer] else: - hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + hs_pool = [ + image_features.hidden_states[layer_idx] + for layer_idx in vision_feature_layer + ] selected_image_feature = torch.cat(hs_pool, dim=-1) if vision_feature_select_strategy == "default": @@ -272,138 +310,158 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): return image_features def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - pixel_values=None, - image_sizes=None, - attention_mask=None, - **kwargs, - ): - """ - Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 - The only differences are: - - add new args token_idx - - add the process of merging images into inputs_embeds - """ - token_idx = kwargs.get("token_idx", None) - if token_idx is None: - return super().prepare_inputs_for_generation( - input_ids=input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - image_sizes=image_sizes, - attention_mask=attention_mask, - **kwargs, + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + attention_mask=None, + **kwargs, + ): + """ + Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 + The only differences are: + - add new args token_idx + - add the process of merging images into inputs_embeds + """ + token_idx = kwargs.get("token_idx", None) + if token_idx is None: + return super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + attention_mask=attention_mask, + **kwargs, + ) + else: + use_flash_attention = kwargs.get("use_flash_attention", False) + flash_attention_recompute = kwargs.get("flash_attention_recompute", False) + + position_ids = kwargs.get("position_ids", None) + labels = kwargs.get("labels", None) + if ( + past_key_values is None + and pixel_values is not None + and input_ids.shape[1] != 1 + ): + vision_feature_select_strategy = kwargs.get( + "vision_feature_select_strategy", None + ) + vision_feature_layer = kwargs.get("vision_feature_layer", None) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + vision_feature_layer = ( + vision_feature_layer + if vision_feature_layer is not None + else self.config.vision_feature_layer ) - else: - use_flash_attention = kwargs.get("use_flash_attention", False) - flash_attention_recompute = kwargs.get("flash_attention_recompute", False) - position_ids = kwargs.get("position_ids", None) - 
labels = kwargs.get("labels", None) - if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1: - vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None) - vision_feature_layer = kwargs.get("vision_feature_layer", None) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy - ) - vision_feature_layer = ( - vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + # 2. Merge text and images + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + vision_feature_select_strategy=vision_feature_select_strategy, + image_newline=self.image_newline, + ) + + special_image_mask = ( + input_ids == self.config.image_token_index + ).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to( + inputs_embeds.device + ) + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - # 1. Extract the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - # 2. Merge text and images - image_features = self.get_image_features( - pixel_values, - image_sizes, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - ) + image_features = image_features.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter( + special_image_mask, image_features + ) - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - vision_feature_select_strategy=vision_feature_select_strategy, - image_newline=self.image_newline, - ) + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + elif past_key_values is not None: + seq_len = input_ids.shape[1] + pad_len = seq_len - token_idx.item() + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where( + first_layer_past_key_value.float().sum(-2) == 0 + ) + # Get the target length + past_length = first_layer_past_key_value.shape[-1] + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one 
passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = extended_attention_mask + attention_mask[:, -pad_len:] = 0 + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = ( + torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 ) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} - # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of - # generation with cache - elif past_key_values is not None: - seq_len = input_ids.shape[1] - pad_len = seq_len - token_idx.item() - input_ids = torch.index_select(input_ids, 1, token_idx - 1) - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - # Get the target length - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "token_idx": token_idx, + "labels": labels, + "use_flash_attention": use_flash_attention, + "flash_attention_recompute": flash_attention_recompute, + } + ) - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = extended_attention_mask - 
attention_mask[:, -pad_len:] = 0 - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if token_idx is not None: - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "token_idx": token_idx, - "labels": labels, - "use_flash_attention": use_flash_attention, - "flash_attention_recompute": flash_attention_recompute, - } - ) - - return model_inputs + return model_inputs diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py index a2a3b3dd..543b07e8 100644 --- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py @@ -427,10 +427,10 @@ class VlmCausalLMBatch(CausalLMBatch): images.append([curr_image]) else: images.append(curr_image) - + if is_warmup is True: images += [images[0]] * (len(texts) - len(images)) - + missing_inputs = 0 dummy_images = None if is_warmup is False:
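
Appended for reference only, not part of the patch: a minimal sketch of the "anyres" patch-grid arithmetic implemented by the helpers reformatted above (get_anyres_image_grid_shape and image_size_to_num_patches). It relies only on select_best_resolution, which the patched module already imports from transformers.image_processing_utils. The grid_pinpoints, patch_size, and image sizes are illustrative assumptions; because the pinpoints are exact multiples of patch_size, the closed-form count below matches the loop in image_size_to_num_patches.

from transformers.image_processing_utils import select_best_resolution

# Illustrative "anyres" grid pinpoints and crop size (assumed example values).
grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
patch_size = 336

for image_size in [(480, 640), (1024, 768)]:  # (height, width) of the raw image
    # Pick the candidate resolution that best fits the image.
    best_h, best_w = select_best_resolution(image_size, grid_pinpoints)
    # get_anyres_image_grid_shape: patch rows/columns after preprocessing.
    grid_h, grid_w = best_h // patch_size, best_w // patch_size
    # image_size_to_num_patches: one crop per grid cell plus the base (resized) image.
    num_patches = grid_h * grid_w + 1
    print(image_size, "->", (best_h, best_w), (grid_h, grid_w), num_patches)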