Fixup Mistral clamping (had issues with CUDA graphs).

Nicolas Patry 2024-07-01 16:31:22 +00:00
parent b686f66727
commit ef8bce0b41
3 changed files with 24 additions and 19 deletions


@@ -4,18 +4,16 @@ import torch
 from typing import Optional
 
 
-@dataclass
-class Seqlen:
-    input_lengths: torch.Tensor
-    cu_seqlen_q: Optional[torch.Tensor]
-    cu_seqlen_k: Optional[torch.Tensor]
+if FLASH_DECODING:
+
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+        cu_seqlen_q: Optional[torch.Tensor]
+        cu_seqlen_k: Optional[torch.Tensor]
 
-    def __init__(self, input_lengths):
-        self.set_input_lengths(input_lengths)
-
-    def set_input_lengths(self, input_lengths):
-        self.input_lengths = input_lengths
-        if FLASH_DECODING:
+        def __init__(self, input_lengths):
+            self.input_lengths = input_lengths
             device = self.input_lengths.device
             shape = self.input_lengths.shape
             cu_seqlen_q = torch.arange(
@@ -24,11 +22,22 @@ class Seqlen:
                 dtype=torch.int32,
             )
             cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # cuda graphs don't like this and this is necessary to clamp within mistral
+            # Although FA2 might not want the clamping
             # cu_seqlen_k[0] = 0
             torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
 
             self.cu_seqlen_q = cu_seqlen_q
             self.cu_seqlen_k = cu_seqlen_k
-        else:
-            self.cu_seqlen_q = None
-            self.cu_seqlen_k = None
+
+        def clamp(self, max):
+            return Seqlen(torch.clamp(self.input_lengths, max=max))
+
+else:
+
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+
+        def clamp(self, max):
+            return Seqlen(torch.clamp(self.input_lengths, max=max))
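
For context, a minimal standalone sketch of how the new clamp helper is intended to be used (this reproduces only the non-FLASH_DECODING Seqlen from the diff above, not the actual server module; the tensor values and the 4096 sliding window are made up for illustration):

from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    input_lengths: torch.Tensor

    def clamp(self, max):
        # Return a new Seqlen with lengths capped at `max`, as paged attention
        # expects in decode mode; callers no longer rebuild a Seqlen by hand.
        return Seqlen(torch.clamp(self.input_lengths, max=max))


# Hypothetical decode-mode values: three sequences and a 4096-token sliding window.
input_lengths = Seqlen(torch.tensor([3000, 5000, 8192], dtype=torch.int32))
max_past_tensor = torch.tensor(4096, dtype=torch.int32)

clamped = input_lengths.clamp(max=max_past_tensor)
print(clamped.input_lengths)  # tensor([3000, 4096, 4096], dtype=torch.int32)

The Mistral and Mixtral diffs below switch their call sites to this pattern, replacing the inline Seqlen(torch.clamp(...)) construction with input_lengths.clamp(max=self.max_past_tensor).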


@@ -513,9 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = Seqlen(
-                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
-            )
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(


@@ -647,9 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = Seqlen(
-                torch.clamp(input_lengths.input_lengths, max=self.max_past_tensor)
-            )
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,