Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-24 00:12:08 +00:00
Fix qwen2.
parent febc488e0e
commit 078084286a
@@ -409,7 +409,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         self.window_size = config.sliding_window
         self.window_size_tensor = (
             torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            if self.window_size is not None
             else None
         )
 
@@ -431,7 +431,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        elif self.max_past is not None:
+        elif self.window_size is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             seqlen = seqlen.clamp(max=self.window_size_tensor)
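For context, both hunks guard on whether a sliding window is configured, and the fix makes them check the attribute that is actually assigned (self.window_size) rather than the stale self.max_past name. Below is a minimal sketch of the corrected pattern, not the TGI source: ToyConfig, ToySlidingWindowModel, and clamp_seqlen are illustrative names, the device is passed in directly instead of taken from the weights, and a plain prefill flag stands in for prefill_cache_indices.

# Minimal sketch (illustrative, not the TGI implementation) of the pattern
# this commit repairs: the tensor built in __init__ and the decode-time clamp
# both key off `window_size`, the attribute that is actually set.
import torch


class ToyConfig:
    sliding_window = 4096  # assumed value for illustration


class ToySlidingWindowModel:
    def __init__(self, config: ToyConfig, device: torch.device):
        self.window_size = config.sliding_window
        # Guard on the attribute assigned above; before the fix the guard read
        # `self.max_past`, which is no longer set on the module.
        self.window_size_tensor = (
            torch.tensor(config.sliding_window, device=device)
            if self.window_size is not None
            else None
        )

    def clamp_seqlen(self, seqlen: torch.Tensor, prefill: bool) -> torch.Tensor:
        if prefill:
            # Prefill keeps the true lengths (the flash-attention kernel needs them).
            return seqlen
        if self.window_size is not None:
            # Decode: paged attention expects lengths clamped to the sliding window.
            return seqlen.clamp(max=self.window_size_tensor)
        return seqlen


if __name__ == "__main__":
    model = ToySlidingWindowModel(ToyConfig(), device=torch.device("cpu"))
    lengths = torch.tensor([1024, 8192, 5000])
    print(model.clamp_seqlen(lengths, prefill=False))  # tensor([1024, 4096, 4096])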
|