Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00.

Fixing Mi{s,x}tral (non functional in Flash Decoding mode though).

commit b686f66727 (parent 1bd52157d8)
```diff
@@ -11,6 +11,9 @@ class Seqlen:
     cu_seqlen_k: Optional[torch.Tensor]
 
     def __init__(self, input_lengths):
+        self.set_input_lengths(input_lengths)
+
+    def set_input_lengths(self, input_lengths):
         self.input_lengths = input_lengths
         if FLASH_DECODING:
             device = self.input_lengths.device
@@ -20,8 +23,8 @@ class Seqlen:
                 device=device,
                 dtype=torch.int32,
             )
-            cu_seqlen_k = torch.empty(shape[-1] + 1, device=device, dtype=torch.int32)
-            cu_seqlen_k[0] = 0
+            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # cu_seqlen_k[0] = 0
             torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
 
             self.cu_seqlen_q = cu_seqlen_q
```
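For context, here is a minimal standalone sketch of what the patched `Seqlen` computes when Flash Decoding is enabled. The `build_cu_seqlens` helper name is mine and the `FLASH_DECODING` branch and dataclass fields are omitted; the point is that switching from `torch.empty` plus an explicit `cu_seqlen_k[0] = 0` to `torch.zeros` preserves the leading zero, which the subsequent `cumsum` into `cu_seqlen_k[1:]` never writes.

```python
import torch

# Hedged sketch of the cumulative-sequence-length computation from the
# patched Seqlen above; the helper name is hypothetical, not part of the diff.
def build_cu_seqlens(input_lengths: torch.Tensor):
    device = input_lengths.device
    shape = input_lengths.shape
    # One query token per sequence during decode: [0, 1, ..., batch_size]
    cu_seqlen_q = torch.arange(shape[0] + 1, device=device, dtype=torch.int32)
    # zeros() already guarantees cu_seqlen_k[0] == 0, so the explicit
    # assignment from the old code is redundant (hence the commented-out line)
    cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
    torch.cumsum(input_lengths, -1, out=cu_seqlen_k[1:])
    return cu_seqlen_q, cu_seqlen_k

cu_q, cu_k = build_cu_seqlens(torch.tensor([3, 5, 2], dtype=torch.int32))
print(cu_q.tolist())  # [0, 1, 2, 3]
print(cu_k.tolist())  # [0, 3, 8, 10]
```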
```diff
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
+    Seqlen,
     paged_attention,
     attention,
     reshape_and_cache,
```
```diff
@@ -512,8 +513,8 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths.input_lengths = torch.clamp(
-                input_lengths.input_lengths, max=self.max_past_tensor
+            input_lengths = Seqlen(
+                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
             )
 
         inputs_embeds = self.embed_tokens(input_ids)
```
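The old code clamped the lengths in place; since `Seqlen` derives `cu_seqlen_k` at construction time, an in-place mutation would leave those cumulative offsets stale, which is presumably why the patch builds a fresh `Seqlen` from the clamped values instead. A hedged toy illustration of that pitfall, using a stripped-down `SimpleSeqlen` stand-in rather than the real class:

```python
import torch

# SimpleSeqlen is a hypothetical stand-in that mimics only the behavior
# relevant here: cumulative offsets are derived once, at construction.
class SimpleSeqlen:
    def __init__(self, input_lengths):
        self.input_lengths = input_lengths
        self.cu_seqlen_k = torch.zeros(
            input_lengths.shape[-1] + 1, dtype=torch.int32
        )
        torch.cumsum(input_lengths, -1, out=self.cu_seqlen_k[1:])

lengths = torch.tensor([6, 9], dtype=torch.int32)
seqlen = SimpleSeqlen(lengths)

# Old approach: in-place clamp mutates the lengths but not cu_seqlen_k,
# which still reflects the unclamped values.
seqlen.input_lengths = torch.clamp(seqlen.input_lengths, max=8)
print(seqlen.cu_seqlen_k.tolist())  # [0, 6, 15] -- stale

# Patched approach: constructing a new object recomputes the offsets.
seqlen = SimpleSeqlen(torch.clamp(lengths, max=8))
print(seqlen.cu_seqlen_k.tolist())  # [0, 6, 14] -- consistent
```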
```diff
@@ -647,8 +647,8 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths.input_lengths = torch.clamp(
-                input_lengths.input_lengths, max=self.max_past_tensor
+            input_lengths = Seqlen(
+                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
             )
 
         hidden_states = self.model(
```
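The Mixtral forward pass receives the identical fix: the clamped lengths are wrapped in a fresh `Seqlen` before being passed on to `self.model(...)`, mirroring the Mistral change above.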