Fix qwen2: check `self.window_size` instead of `self.max_past`.

This commit is contained in:
Nicolas Patry 2025-03-18 10:36:54 +01:00
parent febc488e0e
commit 078084286a
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9

View File

@@ -409,7 +409,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
self.window_size = config.sliding_window
self.window_size_tensor = (
torch.tensor(config.sliding_window, device=weights.device)
- if self.max_past is not None
+ if self.window_size is not None
else None
)
@@ -431,7 +431,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
- elif self.max_past is not None:
+ elif self.window_size is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.window_size_tensor)