diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index 726146d1..1ef55019 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -87,6 +87,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         return_dict: Optional[bool] = None,
         token_idx: Optional[torch.Tensor] = None,
         use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
     ):
 
         if token_idx is not None:
@@ -109,7 +110,7 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 return_dict=return_dict,
                 token_idx=token_idx,
                 use_flash_attention=use_flash_attention,
-                flash_attention_recompute=use_flash_attention,
+                flash_attention_recompute=flash_attention_recompute,
             )
 
             logits = outputs[0]
@@ -149,6 +150,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             )
         else:
             use_flash_attention = kwargs.get("use_flash_attention", False)
+            flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
+            position_ids = kwargs.get("position_ids", None)
             labels = kwargs.get("labels", None)
             if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
 
@@ -169,7 +172,10 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 batch_size, num_patches, num_channels, height, width = pixel_values.shape
                 reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
                 image_features = self.vision_tower(
-                    reshaped_pixel_values, output_hidden_states=True, use_flash_attention=use_flash_attention
+                    reshaped_pixel_values,
+                    output_hidden_states=True,
+                    use_flash_attention=use_flash_attention,
+                    flash_attention_recompute=flash_attention_recompute,
                 )
 
                 selected_image_feature = image_features.hidden_states[vision_feature_layer]
@@ -283,7 +289,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                     "token_idx": token_idx,
                     "labels": labels,
                     "use_flash_attention": use_flash_attention,
+                    "flash_attention_recompute": flash_attention_recompute,
                 }
             )
 
-        return model_inputs
\ No newline at end of file
+        return model_inputs