lint fix.

Nicolas Patry 2024-10-15 18:46:56 +02:00
parent 5c8c5ac81a
commit fc41f0784a
GPG Key ID: D2920555C90F704C
3 changed files with 6 additions and 55 deletions

View File

@@ -1,6 +1,4 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional
@@ -52,53 +50,3 @@ class Seqlen:
    def clamp(self, max):
        # Flash decoding doesn't need to clamp
        return self
=======
if ATTENTION in {"flashinfer", "flashdecoding"}:

    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor
        prefix_lengths: torch.Tensor
        cu_seqlen_q: Optional[torch.Tensor]
        cu_seqlen_k: Optional[torch.Tensor]
        max_q: int
        max_k: int

        def __init__(
            self,
            input_lengths,
            prefix_lengths,
            cu_seqlen_q=None,
            max_q=None,
            max_k=None,
        ):
            self.input_lengths = input_lengths
            self.prefix_lengths = prefix_lengths
            device = self.input_lengths.device
            shape = self.input_lengths.shape
            if cu_seqlen_q is None:
                cu_seqlen_q = torch.arange(
                    shape[0] + 1,
                    device=device,
                    dtype=torch.int32,
                )
                max_q = 1
            else:
                assert max_q is not None
                assert max_k is not None
            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)

            # cuda graphs don't like this and this is necessary to clamp within mistral
            # Although FA2 might not want the clamping
            # cu_seqlen_k[0] = 0
            total = self.input_lengths + self.prefix_lengths
            torch.cumsum(total, -1, out=cu_seqlen_k[1:])

            self.cu_seqlen_q = cu_seqlen_q
            self.cu_seqlen_k = cu_seqlen_k
            self.max_q = max_q
            self.max_k = max_k

        def clamp(self, max):
            # Flash decoding doesn't need to clamp
            return self
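Note (not part of the diff): the Seqlen constructor that survives this cleanup builds the cumulative-sequence-length tensors the flash kernels consume. Below is a standalone sketch of that bookkeeping with illustrative names, outside the server code, and omitting the prefill branch where cu_seqlen_q, max_q and max_k are passed in.

import torch


def build_cu_seqlens(input_lengths: torch.Tensor, prefix_lengths: torch.Tensor):
    """Illustrative reconstruction of the decode-time bookkeeping in Seqlen.__init__."""
    batch_size = input_lengths.shape[0]
    device = input_lengths.device
    # Decode path: each sequence contributes exactly one query token,
    # so the query offsets are simply 0..batch_size.
    cu_seqlen_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32)
    # Keys span the cached prefix plus the new tokens; the key offsets are the
    # prefix sums of those totals, with a leading zero.
    total = (input_lengths + prefix_lengths).to(torch.int32)
    cu_seqlen_k = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
    cu_seqlen_k[1:] = torch.cumsum(total, -1)
    return cu_seqlen_q, cu_seqlen_k


# Two sequences, each decoding 1 new token, with cached prefixes of 3 and 5 tokens.
q, k = build_cu_seqlens(torch.tensor([1, 1]), torch.tensor([3, 5]))
# q == tensor([0, 1, 2], dtype=torch.int32); k == tensor([0, 4, 10], dtype=torch.int32)

The leading zero plus the prefix sums mean sequence i's keys live in the range [cu_seqlen_k[i], cu_seqlen_k[i+1]), which is the layout varlen flash kernels expect.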

View File

@@ -243,6 +243,7 @@ if ATTENTION == "flashinfer":
            sm_scale=softmax_scale,
            window_left=window_size_left,
        )
elif ATTENTION == "flashdecoding":
    if V2:
@@ -351,6 +352,7 @@ elif ATTENTION == "flashdecoding":
            None,
        )
        return out
elif ATTENTION == "paged":
    if V2:
@@ -459,6 +461,7 @@ elif ATTENTION == "paged":
            None,
        )
        return out
else:
    raise RuntimeError(f"Unknown attention {ATTENTION}")
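Aside (not in the diff): this file resolves the attention implementation once at import time from the ATTENTION global, and every branch must define the same entry points or the final else raises. A skeletal sketch of that dispatch pattern follows, with placeholder bodies and an environment-variable lookup standing in for the real globals module.

import os

ATTENTION = os.environ.get("ATTENTION", "paged")  # stand-in for models.globals.ATTENTION

if ATTENTION == "flashinfer":

    def attention(*args, **kwargs):
        # placeholder: the real branch wraps the flashinfer prefill/decode state
        raise NotImplementedError

elif ATTENTION == "flashdecoding":

    def attention(*args, **kwargs):
        # placeholder: the real branch calls the flash-attention kernels
        raise NotImplementedError

elif ATTENTION == "paged":

    def attention(*args, **kwargs):
        # placeholder: the real branch calls the paged attention kernels
        raise NotImplementedError

else:
    raise RuntimeError(f"Unknown attention {ATTENTION}")

Callers import attention without knowing which branch defined it, which is why an unrecognized ATTENTION value has to fail loudly at import time rather than fall through.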

View File

@@ -1679,9 +1679,9 @@ class FlashCausalLM(Model):
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
        cuda_graph["cache_lengths"].zero_()
        cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
            cache_lengths_tensor
        )
        cuda_graph["cache_lengths"][
            : cache_lengths_tensor.shape[0]
        ] = cache_lengths_tensor

        with self._forward_context(
            block_tables=cuda_graph["block_tables"],
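Aside (not in the diff): both formattings above implement the same CUDA-graph input refresh. A replayed graph reads fixed, preallocated buffers, so the current batch has to be copied into them in place and the padded tail zeroed. A minimal illustration of that pattern, with hypothetical sizes and no actual graph capture:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
max_bs = 8  # batch size the graph was captured with (illustrative)

cuda_graph = {
    "input_lengths": torch.zeros(max_bs, dtype=torch.int32, device=device),
    "cache_lengths": torch.zeros(max_bs, dtype=torch.int32, device=device),
}


def refresh_graph_inputs(input_lengths, cache_lengths_tensor):
    # Zero the full buffers so padded slots beyond the current batch are inert,
    # then overwrite the leading slice in place; the buffers themselves never move,
    # which is what keeps the captured graph's pointers valid.
    cuda_graph["input_lengths"].zero_()
    cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
    cuda_graph["cache_lengths"].zero_()
    cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = cache_lengths_tensor


refresh_graph_inputs(
    torch.tensor([1, 1, 1], dtype=torch.int32, device=device),
    torch.tensor([3, 5, 7], dtype=torch.int32, device=device),
)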