From 078084286a4ebeb0c7b27d058638dacc82df63f6 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 18 Mar 2025 10:36:54 +0100
Subject: [PATCH] Fix qwen2.

---
 .../models/custom_modeling/flash_qwen2_modeling.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index c06e5dcc..f5e4e15c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -409,7 +409,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         self.window_size = config.sliding_window
         self.window_size_tensor = (
             torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            if self.window_size is not None
             else None
         )
@@ -431,7 +431,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        elif self.max_past is not None:
+        elif self.window_size is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             seqlen = seqlen.clamp(max=self.window_size_tensor)
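
Note (not part of the patch): the change swaps the stale `self.max_past` guard for `self.window_size`, the attribute that is actually assigned at line 409, so the sliding-window tensor is built and the decode-time clamp is applied again. Below is a minimal, hypothetical sketch of that guarded clamp pattern, assuming `seqlen` and `slots` are plain tensors; the class and method names are illustrative, not TGI's API.

import torch


class SlidingWindowClamp:
    """Hypothetical stand-in for the sliding-window handling the patch restores."""

    def __init__(self, sliding_window, device="cpu"):
        # Guard on window_size (the attribute that exists), mirroring the patch.
        self.window_size = sliding_window
        self.window_size_tensor = (
            torch.tensor(sliding_window, device=device)
            if self.window_size is not None
            else None
        )

    def adjust(self, seqlen, slots, prefill_cache_indices=None):
        if prefill_cache_indices is not None:
            # Prefill: slice slots so they match the sliced kv tensor.
            slots = slots[prefill_cache_indices]
        elif self.window_size is not None:
            # Decode: paged attention expects lengths clamped to the window.
            seqlen = seqlen.clamp(max=self.window_size_tensor)
        return seqlen, slots


# Usage sketch:
# clamp = SlidingWindowClamp(sliding_window=4096)
# seqlen, slots = clamp.adjust(torch.tensor([8192]), torch.arange(10))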