From c6f97fd8842004301d4243fa517d02a74a0c5270 Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Fri, 21 Mar 2025 05:50:08 +0000
Subject: [PATCH] Change default values

Signed-off-by: yuanwu
---
 .../models/custom_modeling/llava_next.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
index 46ad655c..00ecdf95 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -110,8 +110,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         token_idx: Optional[torch.Tensor] = None,
-        use_flash_attention: Optional[bool] = False,
-        flash_attention_recompute: Optional[bool] = False,
+        use_flash_attention: Optional[bool] = True,
+        flash_attention_recompute: Optional[bool] = True,
     ):

         if token_idx is not None:
@@ -337,8 +337,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 **kwargs,
             )
         else:
-            use_flash_attention = kwargs.get("use_flash_attention", False)
-            flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
+            use_flash_attention = kwargs.get("use_flash_attention", True)
+            flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
             position_ids = kwargs.get("position_ids", None)
             labels = kwargs.get("labels", None)
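
A minimal, self-contained sketch (illustrative, not from the repository) of
what the changed kwargs.get() fallback in the second hunk means in practice:
the fallback is only consulted when a caller omits the key, so callers that
already pass use_flash_attention or flash_attention_recompute explicitly see
no behavior change; only callers relying on the implicit default now get
flash attention enabled on Gaudi.

    kwargs = {}  # caller passed neither flag

    # Old behavior: a missing flag fell back to False (flash attention off).
    assert kwargs.get("use_flash_attention", False) is False

    # New behavior: a missing flag falls back to True (flash attention and
    # recompute on by default).
    assert kwargs.get("use_flash_attention", True) is True

    # An explicit caller value always wins over the fallback:
    kwargs = {"use_flash_attention": False}
    assert kwargs.get("use_flash_attention", True) is False

The same reasoning applies to the first hunk: the Optional[bool] parameter
defaults only matter for call sites that do not pass the arguments at all.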