Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00

Fixing non blocked attentions

parent 5e70158b2c
commit 4fa4da3cb6
@@ -5,10 +5,8 @@ import torch
 from typing import Optional
 
 
-if ATTENTION in {"flashinfer", "flashdecoding"}:
-
-    @dataclass
-    class Seqlen:
+@dataclass
+class Seqlen:
     input_lengths: torch.Tensor
     cache_lengths: torch.Tensor
     cu_seqlen_q: Optional[torch.Tensor]
@@ -54,19 +52,3 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
     def clamp(self, max):
         # Flash decoding doesn't need to clamp
         return self
-
-else:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        cache_lengths: torch.Tensor
-        cu_seqlen_q: torch.Tensor
-        max_q: int
-        max_k: int
-
-        def clamp(self, max):
-            if SYSTEM == "rocm":
-                return self
-            raise NotImplementedError("Not implemented seqlen for paged")
-            return Seqlen(torch.clamp(self.input_lengths, max=max))
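The consolidated Seqlen above carries the ragged-batch bookkeeping (per-sequence input lengths, cached prefix lengths, and cumulative query offsets). As a rough sketch of what the cu_seqlen_q field is expected to hold, assuming the usual layout that varlen flash-attention kernels consume (this helper is hypothetical, not code from the repository):

import torch

def build_cu_seqlen(lengths: torch.Tensor) -> torch.Tensor:
    # Prefix sums with a leading zero: [3, 5, 2] -> [0, 3, 8, 10].
    cu = torch.zeros(lengths.shape[0] + 1, dtype=torch.int32, device=lengths.device)
    cu[1:] = torch.cumsum(lengths, dim=0)
    return cu

print(build_cu_seqlen(torch.tensor([3, 5, 2])))  # tensor([ 0,  3,  8, 10], dtype=torch.int32)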
@@ -123,7 +123,7 @@ def paged_attention(
     else:
         if softcap is not None:
            raise RuntimeError("Paged attention doesn't support softcapping")
-        input_lengths = seqlen.input_lengths
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
        from vllm._C import ops
 
        out = torch.empty_like(query)
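The one-line change above is the core fix for the paged (vllm) decode path: when part of a sequence already sits in the KV cache, the kernel has to see the total key length, not only the newly supplied tokens. A toy illustration of the arithmetic, with invented example values:

import torch

input_lengths = torch.tensor([3, 1])    # tokens being processed this step
cache_lengths = torch.tensor([5, 10])   # tokens already in the KV cache
total_key_lengths = input_lengths + cache_lengths
print(total_key_lengths)  # tensor([ 8, 11]): the lengths the paged kernel should attend over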
@@ -243,8 +243,8 @@ if ATTENTION == "flashinfer":
             sm_scale=softmax_scale,
             window_left=window_size_left,
         )
 
-elif V2:
+elif ATTENTION == "flashdecoding":
+    if V2:
 
         def attention(
             q,
@@ -284,7 +284,7 @@ elif V2:
                 None,
             )[0]
 
-else:
+    else:
 
         def attention(
             q: torch.Tensor,
@@ -302,7 +302,9 @@ else:
                     "window_size_left is only available with flash attn v2"
                 )
             if softcap is not None:
-                raise NotImplementedError("softcap is only available with flash attn v2")
+                raise NotImplementedError(
+                    "softcap is only available with flash attn v2"
+                )
 
             # Flash attention v1 requires q, k and v to have the same number of heads
             if k.shape[1] != q.shape[1]:
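The flash-attention v1 fallback referenced above needs q, k and v to share a head count, so multi-query and grouped-query layouts are expanded before the kernel call. A standalone sketch of that expansion, assuming [total_tokens, num_heads, head_dim] tensors as in the added code further down (the helper name is invented for illustration):

import torch

def expand_kv_heads(kv: torch.Tensor, num_q_heads: int) -> torch.Tensor:
    total_tokens, num_kv_heads, head_dim = kv.shape
    if num_kv_heads == num_q_heads:
        return kv
    if num_kv_heads == 1:
        # MQA: broadcast the single KV head across every query head.
        return kv.expand(-1, num_q_heads, -1)
    # GQA: repeat each KV head num_q_heads // num_kv_heads times, then flatten.
    return (
        kv.unsqueeze(2)
        .expand(-1, -1, num_q_heads // num_kv_heads, -1)
        .reshape(total_tokens, -1, head_dim)
    )

# usage sketch: k = expand_kv_heads(k, q.shape[1]); v = expand_kv_heads(v, q.shape[1])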
@@ -349,12 +351,122 @@ else:
                 None,
             )
             return out
+
+elif ATTENTION == "paged":
+    if V2:
+
+        def attention(
+            q,
+            key_cache: torch.Tensor,
+            value_cache: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale,
+            window_size_left=-1,
+            causal=True,
+            softcap=0.0,
+        ):
+            out = torch.empty_like(q)
+            if window_size_left <= 0 and window_size_left != -1:
+                raise ValueError("`window_size_left` must be > 0 or -1")
+            return flash_attn_2_cuda.varlen_fwd(
+                q,
+                key_cache,
+                value_cache,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_k,
+                None,
+                None,
+                None,  # block_tables,
+                None,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                window_size_left,
+                0,
+                softcap,
+                False,
+                None,
+            )[0]
+
+    else:
+
+        def attention(
+            q: torch.Tensor,
+            k: torch.Tensor,
+            v: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale: float,
+            window_size_left: int = -1,
+            causal: bool = True,
+            softcap=None,
+        ):
+            if window_size_left != -1:
+                raise NotImplementedError(
+                    "window_size_left is only available with flash attn v2"
+                )
+            if softcap is not None:
+                raise NotImplementedError(
+                    "softcap is only available with flash attn v2"
+                )
+
+            # Flash attention v1 requires q, k and v to have the same number of heads
+            if k.shape[1] != q.shape[1]:
+                # MQA expand
+                if k.shape[1] == 1:
+                    k = k.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = k.shape
+                    k = (
+                        k.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+            if v.shape[1] != q.shape[1]:
+                # MQA expand
+                if v.shape[1] == 1:
+                    v = v.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = v.shape
+                    v = (
+                        v.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+
+            out = torch.empty_like(q)
+            flash_attn_cuda.fwd(
+                q,
+                k,
+                v,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_q,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                0,
+                None,
+            )
+            return out
+
+else:
+    raise RuntimeError(f"Unknwon attention {ATTENTION}")
 
 # Prefill in the cache with every kind of attention, unless we
 # have a configuration that requires flash-attention v1, which
 # does not support block tables.
-PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
+PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
 
 __all__ = [
     "PREFILL_IN_KV_CACHE",
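For reference, the tightened PREFILL_IN_KV_CACHE expression differs from the old one in exactly two configurations; a quick standalone check of both formulas, using plain loop variables rather than the module's globals:

for attention in ("flashinfer", "flashdecoding", "paged"):
    for v2 in (True, False):
        old = attention != "paged" or v2
        new = attention == "flashinfer" or (attention == "flashdecoding" and v2)
        print(f"{attention:13} V2={v2!s:5} old={old!s:5} new={new!s:5}")

# The expressions disagree for ("flashdecoding", V2=False) and ("paged", V2=True):
# prefill now writes through the KV cache only for flashinfer, or for
# flashdecoding when flash-attention v2 is available.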
@@ -1602,6 +1602,8 @@ class FlashCausalLM(Model):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
+        print(slots)
+
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
@@ -1677,9 +1679,9 @@ class FlashCausalLM(Model):
             cuda_graph["input_lengths"].zero_()
             cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
             cuda_graph["cache_lengths"].zero_()
-            cuda_graph["cache_lengths"][
-                : cache_lengths_tensor.shape[0]
-            ] = cache_lengths_tensor
+            cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
+                cache_lengths_tensor
+            )
 
             with self._forward_context(
                 block_tables=cuda_graph["block_tables"],
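The reflowed assignment above appears to be a pure formatting change, but the surrounding pattern is the interesting part: CUDA graphs replay with fixed-size buffers, so variable-length inputs are copied into a zeroed, pre-allocated tensor on every step. A generic sketch of that pattern (not repository code, names invented):

import torch

def fill_graph_buffer(buffer: torch.Tensor, values: torch.Tensor) -> None:
    # `buffer` was allocated once, at the maximum batch size, when the CUDA
    # graph was captured; only its first len(values) slots are meaningful now.
    buffer.zero_()
    buffer[: values.shape[0]] = values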
@@ -5,9 +5,13 @@ from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
 
-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
-log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 ATTENTION = os.getenv("ATTENTION", "flashinfer")
+default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
+PREFIX_CACHING = os.getenv("PREFIX_CACHING", default_prefix_caching).lower() in {
+    "1",
+    "true",
+}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
     ATTENTION in _expected
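The globals change above switches the environment variable from USE_PREFIX_CACHING to PREFIX_CACHING and makes its default follow the attention backend. A standalone sketch of how the default resolves when the variable is unset, mirroring the added lines rather than importing from the repo:

import os

os.environ.pop("PREFIX_CACHING", None)  # unset: fall back to the backend default
os.environ["ATTENTION"] = "paged"

attention = os.getenv("ATTENTION", "flashinfer")
default_prefix_caching = "1" if attention in {"flashinfer", "flashdecoding"} else "0"
prefix_caching = os.getenv("PREFIX_CACHING", default_prefix_caching).lower() in {"1", "true"}
print(prefix_caching)  # False: the paged backend now opts out of prefix caching by default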