Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Update mistral past.
This commit is contained in:
parent 8fa8cda660
commit 1bd52157d8
@@ -512,7 +512,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths.input_lengths = torch.clamp(
+                input_lengths.input_lengths, max=self.max_past_tensor
+            )

         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
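For context, here is a minimal sketch of the decode-time clamping this hunk performs, assuming the lengths now travel inside a Seqlen-style wrapper object rather than as a bare tensor. The Seqlen class and clamp_for_decode helper below are hypothetical stand-ins for illustration, not the repository's actual API.

    import torch
    from dataclasses import dataclass

    @dataclass
    class Seqlen:
        # Hypothetical stand-in for the object that now carries per-request lengths.
        input_lengths: torch.Tensor

    def clamp_for_decode(seqlen: Seqlen, max_past_tensor: torch.Tensor) -> None:
        # Paged attention in decode mode only ever sees the last `max_past`
        # (sliding-window) KV entries, so the lengths are clamped in place here;
        # the flash attention prefill path keeps the true, unclamped lengths.
        seqlen.input_lengths = torch.clamp(seqlen.input_lengths, max=max_past_tensor)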
@@ -647,7 +647,9 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths.input_lengths = torch.clamp(
+                input_lengths.input_lengths, max=self.max_past_tensor
+            )

         hidden_states = self.model(
             input_ids,
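As a quick illustration of the effect of the clamp itself (the 4096 window size and the lengths below are made up for the example, not taken from the commit): only requests whose length exceeds the sliding window are changed.

    import torch

    max_past_tensor = torch.tensor(4096)            # assumed sliding-window size
    input_lengths = torch.tensor([12, 5000, 4096])  # assumed per-request lengths

    clamped = torch.clamp(input_lengths, max=max_past_tensor)
    print(clamped)  # tensor([  12, 4096, 4096])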