Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
commit fc41f0784a (parent 5c8c5ac81a)

    lint fix.
@@ -1,6 +1,4 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional
@@ -52,53 +50,3 @@ class Seqlen:
        def clamp(self, max):
            # Flash decoding doesn't need to clamp
            return self
=======
if ATTENTION in {"flashinfer", "flashdecoding"}:

    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor
        prefix_lengths: torch.Tensor
        cu_seqlen_q: Optional[torch.Tensor]
        cu_seqlen_k: Optional[torch.Tensor]
        max_q: int
        max_k: int

        def __init__(
            self,
            input_lengths,
            prefix_lengths,
            cu_seqlen_q=None,
            max_q=None,
            max_k=None,
        ):
            self.input_lengths = input_lengths
            self.prefix_lengths = prefix_lengths
            device = self.input_lengths.device
            shape = self.input_lengths.shape
            if cu_seqlen_q is None:
                cu_seqlen_q = torch.arange(
                    shape[0] + 1,
                    device=device,
                    dtype=torch.int32,
                )
                max_q = 1
            else:
                assert max_q is not None
                assert max_k is not None
            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)

            # cuda graphs don't like this and this is necessary to clamp within mistral
            # Although FA2 might not want the clamping
            # cu_seqlen_k[0] = 0
            total = self.input_lengths + self.prefix_lengths
            torch.cumsum(total, -1, out=cu_seqlen_k[1:])

            self.cu_seqlen_q = cu_seqlen_q
            self.cu_seqlen_k = cu_seqlen_k
            self.max_q = max_q
            self.max_k = max_k

        def clamp(self, max):
            # Flash decoding doesn't need to clamp
            return self
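The duplicated `Seqlen` block removed above is where the cumulative sequence-length tensors used by the flashinfer/flash-decoding paths are built: `cu_seqlen_q` defaults to a plain `arange` (one query token per sequence during decode, hence `max_q = 1`), and `cu_seqlen_k` is the running sum of `input_lengths + prefix_lengths`. A minimal standalone sketch of that construction, with made-up batch values rather than anything from the diff:

import torch

# Hypothetical per-sequence lengths for a decode batch of three requests.
input_lengths = torch.tensor([1, 1, 1], dtype=torch.int32)    # new tokens this step
prefix_lengths = torch.tensor([7, 12, 3], dtype=torch.int32)  # tokens already in the KV cache

# cu_seqlen_q: one query token per sequence, so the offsets are just 0..batch_size.
cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

# cu_seqlen_k: prefix sums of the total key length (cached prefix + new tokens),
# with a leading zero, mirroring the cumsum-into-a-slice pattern in the diff.
cu_seqlen_k = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
torch.cumsum(input_lengths + prefix_lengths, -1, out=cu_seqlen_k[1:])

print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]
print(cu_seqlen_k.tolist())  # [0, 8, 21, 25]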
@@ -243,6 +243,7 @@ if ATTENTION == "flashinfer":
            sm_scale=softmax_scale,
            window_left=window_size_left,
        )

elif ATTENTION == "flashdecoding":
    if V2:
@@ -351,6 +352,7 @@ elif ATTENTION == "flashdecoding":
            None,
        )
        return out

elif ATTENTION == "paged":
    if V2:
@@ -459,6 +461,7 @@ elif ATTENTION == "paged":
            None,
        )
        return out

else:
    raise RuntimeError(f"Unknwon attention {ATTENTION}")
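The hunks above touch the module-level dispatch in the attention layer: one implementation is selected once at import time from the `ATTENTION` setting, with `flashinfer`, `flashdecoding`, and `paged` branches and a hard failure for anything else. A minimal sketch of that pattern, with placeholder bodies instead of the real kernels (the env-var lookup is an assumption for illustration):

import os

ATTENTION = os.environ.get("ATTENTION", "paged")  # assumed source of the setting

if ATTENTION == "flashinfer":

    def attention(*args, **kwargs):
        # Placeholder for the flashinfer prefill/decode wrappers.
        raise NotImplementedError

elif ATTENTION == "flashdecoding":

    def attention(*args, **kwargs):
        # Placeholder for flash-attention with a paged KV cache.
        raise NotImplementedError

elif ATTENTION == "paged":

    def attention(*args, **kwargs):
        # Placeholder for the paged attention kernels.
        raise NotImplementedError

else:
    raise RuntimeError(f"Unknown attention {ATTENTION}")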
@@ -1679,9 +1679,9 @@ class FlashCausalLM(Model):
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
        cuda_graph["cache_lengths"].zero_()
        cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
            cache_lengths_tensor
        )
        cuda_graph["cache_lengths"][
            : cache_lengths_tensor.shape[0]
        ] = cache_lengths_tensor

        with self._forward_context(
            block_tables=cuda_graph["block_tables"],
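The `FlashCausalLM` hunk appears to be a formatting-only rewrite of the code that copies the current batch's lengths into the tensors captured by the CUDA graph: replay reuses the same fixed-size buffers, so they are zeroed and then overwritten in place on their leading slice. A small self-contained sketch of that padding pattern (buffer names and sizes are illustrative):

import torch

# Illustrative preallocated buffers sized for the largest batch captured in the graph.
MAX_BATCH_SIZE = 8
cuda_graph = {
    "input_lengths": torch.zeros(MAX_BATCH_SIZE, dtype=torch.int32),
    "cache_lengths": torch.zeros(MAX_BATCH_SIZE, dtype=torch.int32),
}

def write_lengths_into_graph_buffers(input_lengths, cache_lengths_tensor):
    # The captured tensors must keep their shape and storage, so update in place:
    # clear the whole buffer, then copy the (possibly smaller) batch into its prefix.
    cuda_graph["input_lengths"].zero_()
    cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
    cuda_graph["cache_lengths"].zero_()
    cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = cache_lengths_tensor

write_lengths_into_graph_buffers(
    torch.tensor([1, 1, 1], dtype=torch.int32),
    torch.tensor([7, 12, 3], dtype=torch.int32),
)
print(cuda_graph["input_lengths"].tolist())  # [1, 1, 1, 0, 0, 0, 0, 0]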