Mirror of https://github.com/huggingface/text-generation-inference.git
Hotfixing qwen2 and starcoder2 (which also get clamping).
parent 963b6c6f0f
commit 57541d5e88
@@ -368,7 +368,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
@@ -534,7 +534,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
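
Note: both hunks make the same one-line change, swapping the functional form torch.clamp(input_lengths, max=self.max_past_tensor) for the equivalent Tensor method input_lengths.clamp(max=self.max_past_tensor); behavior is unchanged. Below is a minimal standalone sketch of the clamping the comment describes, with an assumed sliding-window size of 4096 and made-up lengths (neither value is taken from this commit):

import torch

# Hypothetical sliding-window size; in the models above this comes from
# the config and is held as self.max_past / self.max_past_tensor.
max_past = 4096
max_past_tensor = torch.tensor(max_past, dtype=torch.int32)

# Made-up per-sequence lengths in a decode batch.
input_lengths = torch.tensor([100, 5000, 4096], dtype=torch.int32)

# The two spellings in the diff compute the same result: cap each length
# at the window size, since paged attention only keeps the last max_past
# tokens of KV cache.
a = torch.clamp(input_lengths, max=max_past_tensor)
b = input_lengths.clamp(max=max_past_tensor)

assert torch.equal(a, b)
print(b)  # tensor([ 100, 4096, 4096], dtype=torch.int32)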