mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Update mistral past.
This commit is contained in:
parent
8fa8cda660
commit
1bd52157d8
@ -512,7 +512,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
||||
elif self.max_past is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
|
||||
input_lengths.input_lengths = torch.clamp(
|
||||
input_lengths.input_lengths, max=self.max_past_tensor
|
||||
)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
|
@ -647,7 +647,9 @@ class FlashMixtralForCausalLM(torch.nn.Module):
|
||||
elif self.max_past is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
|
||||
input_lengths.input_lengths = torch.clamp(
|
||||
input_lengths.input_lengths, max=self.max_past_tensor
|
||||
)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
Loading…
Reference in New Issue
Block a user