Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
lint fix.
parent 5c8c5ac81a
commit fc41f0784a
@@ -1,6 +1,4 @@
 from dataclasses import dataclass
-from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import ATTENTION
 import torch
 from typing import Optional
 
@@ -52,53 +50,3 @@ class Seqlen:
     def clamp(self, max):
         # Flash decoding doesn't need to clamp
         return self
-=======
-if ATTENTION in {"flashinfer", "flashdecoding"}:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
-        cu_seqlen_q: Optional[torch.Tensor]
-        cu_seqlen_k: Optional[torch.Tensor]
-        max_q: int
-        max_k: int
-
-        def __init__(
-            self,
-            input_lengths,
-            prefix_lengths,
-            cu_seqlen_q=None,
-            max_q=None,
-            max_k=None,
-        ):
-            self.input_lengths = input_lengths
-            self.prefix_lengths = prefix_lengths
-            device = self.input_lengths.device
-            shape = self.input_lengths.shape
-            if cu_seqlen_q is None:
-                cu_seqlen_q = torch.arange(
-                    shape[0] + 1,
-                    device=device,
-                    dtype=torch.int32,
-                )
-                max_q = 1
-            else:
-                assert max_q is not None
-                assert max_k is not None
-            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
-
-            # cuda graphs don't like this and this is necessary to clamp within mistral
-            # Although FA2 might not want the clamping
-            # cu_seqlen_k[0] = 0
-            total = self.input_lengths + self.prefix_lengths
-            torch.cumsum(total, -1, out=cu_seqlen_k[1:])
-
-            self.cu_seqlen_q = cu_seqlen_q
-            self.cu_seqlen_k = cu_seqlen_k
-            self.max_q = max_q
-            self.max_k = max_k
-
-        def clamp(self, max):
-            # Flash decoding doesn't need to clamp
-            return self
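The block removed by this hunk is everything after the stray `=======` conflict marker: a second, conditionally defined copy of Seqlen apparently left behind by an earlier merge, which this lint pass deletes together with the now-unused SYSTEM and ATTENTION imports in the first hunk. For reference, the deleted constructor builds the cu_seqlen_q / cu_seqlen_k offset tensors used by the variable-length attention paths; below is a minimal standalone sketch of that computation with made-up batch lengths (only the tensor names are taken from the removed code).

import torch

# Illustrative decode batch of 3 requests: one new token each, plus cached prefixes.
input_lengths = torch.tensor([1, 1, 1], dtype=torch.int32)
prefix_lengths = torch.tensor([7, 12, 3], dtype=torch.int32)

batch_size = input_lengths.shape[0]

# Query offsets: with a single new token per request this is simply 0..batch_size.
cu_seqlen_q = torch.arange(batch_size + 1, dtype=torch.int32)

# Key offsets: cumulative sum of the total (cached + new) length per request,
# written into positions 1.., exactly as the removed __init__ does with cumsum.
total = input_lengths + prefix_lengths
cu_seqlen_k = torch.zeros(batch_size + 1, dtype=torch.int32)
torch.cumsum(total, -1, out=cu_seqlen_k[1:])

print(cu_seqlen_q)  # tensor([0, 1, 2, 3], dtype=torch.int32)
print(cu_seqlen_k)  # tensor([ 0,  8, 21, 25], dtype=torch.int32)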
@@ -243,6 +243,7 @@ if ATTENTION == "flashinfer":
             sm_scale=softmax_scale,
             window_left=window_size_left,
         )
+
 elif ATTENTION == "flashdecoding":
     if V2:
 
@@ -351,6 +352,7 @@ elif ATTENTION == "flashdecoding":
                 None,
             )
             return out
+
 elif ATTENTION == "paged":
     if V2:
 
@@ -459,6 +461,7 @@ elif ATTENTION == "paged":
                 None,
             )
             return out
+
 else:
     raise RuntimeError(f"Unknwon attention {ATTENTION}")
 
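The three single-line hunks above touch the module-level backend dispatch: each adds one blank line before the next elif/else branch, a pure formatting change. For context, the shape of that dispatch is sketched below with stub bodies and a placeholder signature; everything except the if/elif/else structure is an assumption, and ATTENTION here is a plain variable rather than the value imported from text_generation_server.models.globals.

import os

# Stand-in for the real ATTENTION global selected at startup.
ATTENTION = os.environ.get("ATTENTION", "paged")

if ATTENTION == "flashinfer":

    def attention(q, k, v, softmax_scale, window_size_left=-1):
        raise NotImplementedError("flashinfer stub")

elif ATTENTION == "flashdecoding":

    def attention(q, k, v, softmax_scale, window_size_left=-1):
        raise NotImplementedError("flashdecoding stub")

elif ATTENTION == "paged":

    def attention(q, k, v, softmax_scale, window_size_left=-1):
        raise NotImplementedError("paged stub")

else:
    raise RuntimeError(f"Unknown attention {ATTENTION}")

Because the branch is taken once at import time, callers always see a single attention function and pay no per-call backend check.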
@@ -1679,9 +1679,9 @@ class FlashCausalLM(Model):
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
         cuda_graph["cache_lengths"].zero_()
-        cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
-            cache_lengths_tensor
-        )
+        cuda_graph["cache_lengths"][
+            : cache_lengths_tensor.shape[0]
+        ] = cache_lengths_tensor
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],