From b686f667275f473456407c5454704d779dfbdd9d Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 1 Jul 2024 16:16:21 +0000
Subject: [PATCH] Fixing Mi{s,x}tral (non functional in Flash Decoding mode though).

---
 server/text_generation_server/layers/attention/common.py | 7 +++++--
 .../models/custom_modeling/flash_mistral_modeling.py     | 5 +++--
 .../models/custom_modeling/flash_mixtral_modeling.py     | 4 ++--
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py
index ca74bdc2..a481b9f0 100644
--- a/server/text_generation_server/layers/attention/common.py
+++ b/server/text_generation_server/layers/attention/common.py
@@ -11,6 +11,9 @@ class Seqlen:
     cu_seqlen_k: Optional[torch.Tensor]
 
     def __init__(self, input_lengths):
+        self.set_input_lengths(input_lengths)
+
+    def set_input_lengths(self, input_lengths):
         self.input_lengths = input_lengths
         if FLASH_DECODING:
             device = self.input_lengths.device
@@ -20,8 +23,8 @@ class Seqlen:
                 device=device,
                 dtype=torch.int32,
             )
-            cu_seqlen_k = torch.empty(shape[-1] + 1, device=device, dtype=torch.int32)
-            cu_seqlen_k[0] = 0
+            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # cu_seqlen_k[0] = 0
             torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
 
             self.cu_seqlen_q = cu_seqlen_q
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 51d9da44..0e73f48d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
+    Seqlen,
     paged_attention,
     attention,
     reshape_and_cache,
@@ -512,8 +513,8 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths.input_lengths = torch.clamp(
-                input_lengths.input_lengths, max=self.max_past_tensor
+            input_lengths = Seqlen(
+                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
             )
 
         inputs_embeds = self.embed_tokens(input_ids)
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 3395e627..540139c3 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -647,8 +647,8 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths.input_lengths = torch.clamp(
-                input_lengths.input_lengths, max=self.max_past_tensor
+            input_lengths = Seqlen(
+                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
             )
 
         hidden_states = self.model(
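
For reference, a minimal standalone sketch of the reworked Seqlen on the FLASH_DECODING path, assuming 1-D int32 input_lengths; the class body is simplified from server/text_generation_server/layers/attention/common.py (the non-flash-decoding branch is dropped) and the example tensor values are illustrative only.

from typing import Optional

import torch


class Seqlen:
    input_lengths: torch.Tensor
    cu_seqlen_q: Optional[torch.Tensor]
    cu_seqlen_k: Optional[torch.Tensor]

    def __init__(self, input_lengths: torch.Tensor):
        self.set_input_lengths(input_lengths)

    def set_input_lengths(self, input_lengths: torch.Tensor):
        # Mirrors the FLASH_DECODING branch of the patched class.
        self.input_lengths = input_lengths
        device = input_lengths.device
        shape = input_lengths.shape
        # One query token per sequence during decode: 0, 1, ..., batch_size.
        self.cu_seqlen_q = torch.arange(
            shape[-1] + 1, device=device, dtype=torch.int32
        )
        # torch.zeros already leaves cu_seqlen_k[0] == 0; cumsum fills the rest.
        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
        torch.cumsum(input_lengths, -1, out=cu_seqlen_k[1:])
        self.cu_seqlen_k = cu_seqlen_k


# Callers rebuild the wrapper instead of mutating it in place (illustrative values):
lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
seqlen = Seqlen(lengths)
clamped = Seqlen(torch.clamp(seqlen.input_lengths, max=4))
print(seqlen.cu_seqlen_k)   # [0, 3, 8, 10]
print(clamped.cu_seqlen_k)  # [0, 3, 7, 9]

Constructing a fresh Seqlen from the clamped lengths, as the Mi{s,x}tral forward passes now do, keeps cu_seqlen_k in sync with the clamped values rather than leaving the previously computed cumsum stale.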