Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Fixup mistral clamping (had issues with cuda graphs).
commit ef8bce0b41
parent b686f66727
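The commit message points at CUDA graph capture. As context, here is a minimal, purely illustrative sketch of the constraint (none of the names or values below come from this repository): a CUDA graph replays a fixed sequence of kernels on fixed memory addresses, so Python-side work such as constructing fresh tensors or wrapper objects per step runs only once, at capture time.

import torch

# Illustrative only; assumes a CUDA device is available.
assert torch.cuda.is_available()

static_input = torch.zeros(8, device="cuda")
static_output = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    static_output.copy_(torch.clamp(static_input, max=5.0))
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output.copy_(torch.clamp(static_input, max=5.0))

# Replay re-runs only the captured kernels; new data must be copied into the
# same static_input buffer and results read back from static_output.
static_input.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
graph.replay()
print(static_output.tolist())  # [0, 1, 2, 3, 4, 5, 5, 5]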
@@ -4,18 +4,16 @@ import torch
 from typing import Optional
 
 
-@dataclass
-class Seqlen:
-    input_lengths: torch.Tensor
-    cu_seqlen_q: Optional[torch.Tensor]
-    cu_seqlen_k: Optional[torch.Tensor]
-
-    def __init__(self, input_lengths):
-        self.set_input_lengths(input_lengths)
-
-    def set_input_lengths(self, input_lengths):
-        self.input_lengths = input_lengths
-        if FLASH_DECODING:
+if FLASH_DECODING:
+
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+        cu_seqlen_q: Optional[torch.Tensor]
+        cu_seqlen_k: Optional[torch.Tensor]
+
+        def __init__(self, input_lengths):
+            self.input_lengths = input_lengths
             device = self.input_lengths.device
             shape = self.input_lengths.shape
             cu_seqlen_q = torch.arange(
@@ -24,11 +22,22 @@ class Seqlen:
                 dtype=torch.int32,
             )
             cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # cuda graphs don't like this and this is necessary to clamp within mistral
+            # Although FA2 might not want the clamping
             # cu_seqlen_k[0] = 0
             torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
 
             self.cu_seqlen_q = cu_seqlen_q
             self.cu_seqlen_k = cu_seqlen_k
-        else:
-            self.cu_seqlen_q = None
-            self.cu_seqlen_k = None
+
+        def clamp(self, max):
+            return Seqlen(torch.clamp(self.input_lengths, max=max))
+
+else:
+
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+
+        def clamp(self, max):
+            return Seqlen(torch.clamp(self.input_lengths, max=max))
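For reference, a standalone sketch of the bookkeeping the FLASH_DECODING branch above performs (the lengths are invented for illustration; only the tensor math mirrors the diff):

import torch

# Three sequences with 3, 5 and 2 tokens of history.
input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
shape = input_lengths.shape

# One query token per sequence during decode: [0, 1, 2, 3].
cu_seqlen_q = torch.arange(shape[-1] + 1, dtype=torch.int32)

# Prefix sums of the key lengths, written into a preallocated buffer: [0, 3, 8, 10].
cu_seqlen_k = torch.zeros(shape[-1] + 1, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_k[1:])

print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]
print(cu_seqlen_k.tolist())  # [0, 3, 8, 10]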
@@ -513,9 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = Seqlen(
-                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
-            )
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
@@ -647,9 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = Seqlen(
-                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
-            )
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
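The same call-site change is applied to both FlashMistralForCausalLM and FlashMixtralForCausalLM. A rough sketch of the before/after behaviour, using the simplified (non-FLASH_DECODING) Seqlen from the diff and a made-up max_past_tensor:

from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    input_lengths: torch.Tensor

    def clamp(self, max):
        return Seqlen(torch.clamp(self.input_lengths, max=max))


max_past_tensor = torch.tensor(8, dtype=torch.int32)
input_lengths = Seqlen(torch.tensor([3, 12, 7], dtype=torch.int32))

# Before: the caller rebuilt the wrapper itself:
#     input_lengths = Seqlen(torch.clamp(input_lengths.input_lengths, max=max_past_tensor))
# After: clamping is a method, so each Seqlen variant can rebuild its own
# cu_seqlen bookkeeping (or skip it) internally:
input_lengths = input_lengths.clamp(max=max_past_tensor)
print(input_lengths.input_lengths.tolist())  # [3, 8, 7]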