Fixing Mi{s,x}tral (still non-functional in Flash Decoding mode, though).

Nicolas Patry 2024-07-01 16:16:21 +00:00
parent 1bd52157d8
commit b686f66727
3 changed files with 10 additions and 6 deletions

Changed file 1 of 3 (Seqlen helper in text_generation_server.layers.attention):

@@ -11,6 +11,9 @@ class Seqlen:
     cu_seqlen_k: Optional[torch.Tensor]
 
     def __init__(self, input_lengths):
+        self.set_input_lengths(input_lengths)
+
+    def set_input_lengths(self, input_lengths):
         self.input_lengths = input_lengths
         if FLASH_DECODING:
             device = self.input_lengths.device
@@ -20,8 +23,8 @@ class Seqlen:
                 device=device,
                 dtype=torch.int32,
             )
-            cu_seqlen_k = torch.empty(shape[-1] + 1, device=device, dtype=torch.int32)
-            cu_seqlen_k[0] = 0
+            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # cu_seqlen_k[0] = 0
             torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
             self.cu_seqlen_q = cu_seqlen_q

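For reference, the cu_seqlen_k change above amounts to letting torch.zeros provide the leading zero instead of writing it explicitly, which is why the cu_seqlen_k[0] = 0 line is commented out. A minimal standalone sketch of that construction (not part of the diff; plain CPU tensors, outside the Seqlen class):

    import torch

    input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

    # zeros() already initializes index 0, so no explicit cu_seqlen_k[0] = 0 is needed
    cu_seqlen_k = torch.zeros(input_lengths.shape[-1] + 1, dtype=torch.int32)
    torch.cumsum(input_lengths, -1, out=cu_seqlen_k[1:])

    print(cu_seqlen_k)  # tensor([0, 3, 8, 10], dtype=torch.int32)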
Changed file 2 of 3 (FlashMistralForCausalLM modeling):

@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
+    Seqlen,
     paged_attention,
     attention,
     reshape_and_cache,
@@ -512,8 +513,8 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths.input_lengths = torch.clamp(
-                input_lengths.input_lengths, max=self.max_past_tensor
+            input_lengths = Seqlen(
+                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
             )
         inputs_embeds = self.embed_tokens(input_ids)

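In the Mistral (and Mixtral, below) forward path, the likely motivation for wrapping the clamped lengths in a fresh Seqlen, rather than assigning to input_lengths.input_lengths in place, is that the cumulative-length tensors are derived from input_lengths in Seqlen.__init__ (via set_input_lengths above), so a plain attribute assignment would leave them stale. A toy illustration of that difference, using a hypothetical stand-in class (ToySeqlen is not the TGI class, only the same derivation pattern):

    import torch

    class ToySeqlen:
        # Derives cumulative sequence lengths from input_lengths, mirroring
        # what Seqlen does in Flash Decoding mode.
        def __init__(self, input_lengths):
            self.input_lengths = input_lengths
            self.cu_seqlen_k = torch.zeros(
                input_lengths.shape[-1] + 1, dtype=torch.int32
            )
            torch.cumsum(input_lengths, -1, out=self.cu_seqlen_k[1:])

    seqlen = ToySeqlen(torch.tensor([3, 9, 2], dtype=torch.int32))
    max_past = 4

    # Old pattern: in-place mutation updates .input_lengths only; the derived
    # cu_seqlen_k still reflects the unclamped lengths.
    seqlen.input_lengths = torch.clamp(seqlen.input_lengths, max=max_past)
    print(seqlen.cu_seqlen_k)  # tensor([0, 3, 12, 14], dtype=torch.int32)

    # New pattern: rebuilding the object recomputes the derived tensors from
    # the clamped lengths.
    seqlen = ToySeqlen(torch.clamp(seqlen.input_lengths, max=max_past))
    print(seqlen.cu_seqlen_k)  # tensor([0, 3, 7, 9], dtype=torch.int32)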
Changed file 3 of 3 (FlashMixtralForCausalLM modeling):

@@ -647,8 +647,8 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths.input_lengths = torch.clamp(
-                input_lengths.input_lengths, max=self.max_past_tensor
+            input_lengths = Seqlen(
+                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
             )
         hidden_states = self.model(