Hotfixing qwen2 and starcoder2 (which also get clamping).

This commit is contained in:
Nicolas Patry 2024-07-02 14:26:16 +02:00
parent 963b6c6f0f
commit 57541d5e88
2 changed files with 2 additions and 2 deletions

@@ -368,7 +368,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
         hidden_states = self.model(
             input_ids,

@@ -534,7 +534,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
         hidden_states = self.model(
             input_ids,
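
For context on the comment in both hunks, here is a minimal sketch, not from this repository, of what the clamping does: in decode mode each sequence's length is capped at the sliding-window size so paged attention never indexes past the cached window, while prefill passes the true lengths to the flash attention kernel. The window size and batch lengths below are made-up values for illustration.

import torch

# Hypothetical sliding-window size; in the models above this comes from
# self.max_past / self.max_past_tensor.
max_past_tensor = torch.tensor(4096, dtype=torch.int32)

# Hypothetical true sequence lengths for a decode batch.
input_lengths = torch.tensor([1024, 8192, 5000], dtype=torch.int32)

# Tensor.clamp with a tensor `max` caps each entry element-wise, matching
# the hotfixed line in both files.
clamped = input_lengths.clamp(max=max_past_tensor)
print(clamped)  # tensor([1024, 4096, 4096], dtype=torch.int32)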