From 1bd52157d87f09a8ca7edc8a7fe4ed2a81e23ce1 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 1 Jul 2024 13:19:26 +0000 Subject: [PATCH] Update mistral past. --- .../models/custom_modeling/flash_mistral_modeling.py | 4 +++- .../models/custom_modeling/flash_mixtral_modeling.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 8faf8ed0..51d9da44 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -512,7 +512,9 @@ class FlashMistralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths.input_lengths = torch.clamp( + input_lengths.input_lengths, max=self.max_past_tensor + ) inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 564bee37..3395e627 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -647,7 +647,9 @@ class FlashMixtralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths.input_lengths = torch.clamp( + input_lengths.input_lengths, max=self.max_past_tensor + ) hidden_states = self.model( input_ids,