From bb6200503cb2b8fc7b48d333709b802f0dd1fb05 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Fri, 15 Dec 2023 01:18:39 +0100
Subject: [PATCH] fix: max_past default value must be -1, not 0 (#1348)

---
 .../models/custom_modeling/flash_mistral_modeling.py | 2 +-
 .../models/custom_modeling/flash_mixtral_modeling.py | 2 +-
 server/text_generation_server/utils/flash_attn.py    | 3 +++
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index b97866f7..afeaf7e5 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -149,7 +149,7 @@ class MistralAttention(torch.nn.Module):
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else 0
+            config.sliding_window if config.sliding_window is not None else -1
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index ff2ed9fd..35bb3735 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -204,7 +204,7 @@ class MixtralAttention(torch.nn.Module):
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else 0
+            config.sliding_window if config.sliding_window is not None else -1
        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 3237df82..02f01e65 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -72,6 +72,9 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
+    if window_size_left <= 0 and window_size_left != -1:
+        raise ValueError("`window_size_left` must be > 0 or -1")
+
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
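
For context, a minimal sketch (not part of the patch, and a simplified stand-in for the real `attention` wrapper, which also takes the query/key/value tensors and `softmax_scale`) of why the old default of 0 conflicts with the new guard while -1 does not:

# Hypothetical, simplified stand-in for the guard added in flash_attn.py.
# -1 means "no sliding window"; every other non-positive value is rejected.
def attention(window_size_left=-1):
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")
    return window_size_left

# Before the fix, a config without `sliding_window` set max_past to 0, which
# the guard above now rejects:
#   attention(window_size_left=0)   # ValueError
# With the corrected default of -1, the sliding window is simply disabled:
#   attention(window_size_left=-1)  # ok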