Update Mistral/Mixtral `max_past` clamping: clamp `input_lengths.input_lengths` instead of the container object.

This commit is contained in:
Nicolas Patry 2024-07-01 13:19:26 +00:00
parent 8fa8cda660
commit 1bd52157d8
2 changed files with 6 additions and 2 deletions

View File

@ -512,7 +512,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
input_lengths.input_lengths = torch.clamp(
input_lengths.input_lengths, max=self.max_past_tensor
)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(

View File

@ -647,7 +647,9 @@ class FlashMixtralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
input_lengths.input_lengths = torch.clamp(
input_lengths.input_lengths, max=self.max_past_tensor
)
hidden_states = self.model(
input_ids,