From ef8bce0b4131ebbc7ccd5dbdfc60908649b730ab Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 1 Jul 2024 16:31:22 +0000
Subject: [PATCH] Fixup mistral clamping (had issues with cuda graphs).

---
 .../layers/attention/common.py                | 35 ++++++++++++-------
 .../custom_modeling/flash_mistral_modeling.py |  4 +--
 .../custom_modeling/flash_mixtral_modeling.py |  4 +--
 3 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py
index a481b9f0..ce9f5f32 100644
--- a/server/text_generation_server/layers/attention/common.py
+++ b/server/text_generation_server/layers/attention/common.py
@@ -4,18 +4,16 @@ import torch
 from typing import Optional
 
 
-@dataclass
-class Seqlen:
-    input_lengths: torch.Tensor
-    cu_seqlen_q: Optional[torch.Tensor]
-    cu_seqlen_k: Optional[torch.Tensor]
+if FLASH_DECODING:
 
-    def __init__(self, input_lengths):
-        self.set_input_lengths(input_lengths)
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+        cu_seqlen_q: Optional[torch.Tensor]
+        cu_seqlen_k: Optional[torch.Tensor]
 
-    def set_input_lengths(self, input_lengths):
-        self.input_lengths = input_lengths
-        if FLASH_DECODING:
+        def __init__(self, input_lengths):
+            self.input_lengths = input_lengths
             device = self.input_lengths.device
             shape = self.input_lengths.shape
             cu_seqlen_q = torch.arange(
@@ -24,11 +22,22 @@ class Seqlen:
                 dtype=torch.int32,
             )
             cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # cuda graphs don't like this and this is necessary to clamp within mistral
+            # Although FA2 might not want the clamping
             # cu_seqlen_k[0] = 0
             torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
 
             self.cu_seqlen_q = cu_seqlen_q
             self.cu_seqlen_k = cu_seqlen_k
-        else:
-            self.cu_seqlen_q = None
-            self.cu_seqlen_k = None
+
+        def clamp(self, max):
+            return Seqlen(torch.clamp(self.input_lengths, max=max))
+
+else:
+
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+
+        def clamp(self, max):
+            return Seqlen(torch.clamp(self.input_lengths, max=max))
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 0e73f48d..69ed5f64 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -513,9 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = Seqlen(
-                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
-            )
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 540139c3..2d6a7f97 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -647,9 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = Seqlen(
-                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
-            )
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
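Usage sketch (illustrative only, not part of the patch): with the two Seqlen variants above, decode-time call sites clamp through the new helper instead of rebuilding a Seqlen by hand. The class below is a trimmed copy of the non-FLASH_DECODING variant; the tensor values and the max_past bound are made up.

from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    # Simplified non-FLASH_DECODING variant from the patch above.
    input_lengths: torch.Tensor

    def clamp(self, max):
        # Return a new Seqlen with every length capped at `max`.
        return Seqlen(torch.clamp(self.input_lengths, max=max))


# Hypothetical decode-time call site, mirroring the mistral/mixtral change.
input_lengths = Seqlen(torch.tensor([3, 4096, 17], dtype=torch.int32))
max_past = 2048  # stand-in for self.max_past_tensor (the sliding-window size)
input_lengths = input_lengths.clamp(max=max_past)
print(input_lengths.input_lengths)  # the 4096 entry is capped at 2048, the rest unchanged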