Fixing non-blocked attentions

Nicolas Patry 2024-10-15 16:12:00 +02:00
parent 5e70158b2c
commit 4fa4da3cb6
4 changed files with 265 additions and 165 deletions


@@ -5,8 +5,6 @@ import torch
 from typing import Optional

-if ATTENTION in {"flashinfer", "flashdecoding"}:

 @dataclass
 class Seqlen:
     input_lengths: torch.Tensor
@@ -54,19 +52,3 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
     def clamp(self, max):
         # Flash decoding doesn't need to clamp
         return self
-
-else:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        cache_lengths: torch.Tensor
-        cu_seqlen_q: torch.Tensor
-        max_q: int
-        max_k: int
-
-        def clamp(self, max):
-            if SYSTEM == "rocm":
-                return self
-            raise NotImplementedError("Not implemented seqlen for paged")
-            return Seqlen(torch.clamp(self.input_lengths, max=max))
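With the else branch removed, common.py keeps a single Seqlen dataclass for every attention backend; the paged-only variant and its NotImplementedError clamp are gone. A rough sketch of the surviving class, reconstructed only from the fields this commit touches (the real definition may carry more attributes and a custom __init__):

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class Seqlen:
    input_lengths: torch.Tensor                   # new tokens per sequence
    cache_lengths: torch.Tensor                   # tokens already in the KV cache
    cu_seqlen_q: Optional[torch.Tensor] = None    # cumulative query lengths
    cu_seqlen_k: Optional[torch.Tensor] = None    # cumulative key lengths
    max_q: int = 0
    max_k: int = 0

    def clamp(self, max):
        # Flash decoding doesn't need to clamp
        return self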


@@ -123,7 +123,7 @@ def paged_attention(
     else:
         if softcap is not None:
             raise RuntimeError("Paged attention doesn't support softcapping")
-        input_lengths = seqlen.input_lengths
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
         from vllm._C import ops

         out = torch.empty_like(query)
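The functional change in this hunk is the length handed to the paged kernel: with prefix caching, part of a sequence already lives in the KV cache, so the keys/values to attend over are the cached tokens plus the new input tokens. A toy illustration with made-up numbers:

import torch

input_lengths = torch.tensor([3, 5])   # tokens being processed this step, per sequence
cache_lengths = torch.tensor([10, 0])  # tokens already sitting in the KV cache
kv_lengths = input_lengths + cache_lengths
print(kv_lengths)  # tensor([13,  5]); passing only input_lengths would skip the cached prefix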
@@ -243,8 +243,8 @@ if ATTENTION == "flashinfer":
             sm_scale=softmax_scale,
             window_left=window_size_left,
         )

-elif V2:
+elif ATTENTION == "flashdecoding":
+    if V2:

         def attention(
             q,
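Together with the block added further down in this file, the module-level dispatch in cuda.py now picks an implementation per backend instead of falling through on V2 alone. A condensed, illustrative outline of the resulting structure (bodies elided; as we read it, V2 reflects whether the flash-attention v2 kernels are importable):

ATTENTION = "flashdecoding"  # illustrative value
V2 = True                    # illustrative value

if ATTENTION == "flashinfer":
    def attention(*args, **kwargs): ...      # flashinfer prefill wrapper
elif ATTENTION == "flashdecoding":
    if V2:
        def attention(*args, **kwargs): ...  # flash-attn v2 varlen kernel
    else:
        def attention(*args, **kwargs): ...  # flash-attn v1 fallback
elif ATTENTION == "paged":
    if V2:
        def attention(*args, **kwargs): ...  # flash-attn v2, called without block tables
    else:
        def attention(*args, **kwargs): ...  # flash-attn v1 fallback
else:
    raise RuntimeError(f"Unknown attention {ATTENTION}")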
@@ -302,7 +302,9 @@ else:
                 "window_size_left is only available with flash attn v2"
             )
         if softcap is not None:
-            raise NotImplementedError("softcap is only available with flash attn v2")
+            raise NotImplementedError(
+                "softcap is only available with flash attn v2"
+            )

         # Flash attention v1 requires q, k and v to have the same number of heads
         if k.shape[1] != q.shape[1]:
@@ -349,12 +351,122 @@ else:
                 None,
             )
             return out
+
+elif ATTENTION == "paged":
+    if V2:
+
+        def attention(
+            q,
+            key_cache: torch.Tensor,
+            value_cache: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale,
+            window_size_left=-1,
+            causal=True,
+            softcap=0.0,
+        ):
+            out = torch.empty_like(q)
+            if window_size_left <= 0 and window_size_left != -1:
+                raise ValueError("`window_size_left` must be > 0 or -1")
+            return flash_attn_2_cuda.varlen_fwd(
+                q,
+                key_cache,
+                value_cache,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_k,
+                None,
+                None,
+                None,  # block_tables,
+                None,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                window_size_left,
+                0,
+                softcap,
+                False,
+                None,
+            )[0]
+
+    else:
+
+        def attention(
+            q: torch.Tensor,
+            k: torch.Tensor,
+            v: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale: float,
+            window_size_left: int = -1,
+            causal: bool = True,
+            softcap=None,
+        ):
+            if window_size_left != -1:
+                raise NotImplementedError(
+                    "window_size_left is only available with flash attn v2"
+                )
+            if softcap is not None:
+                raise NotImplementedError(
+                    "softcap is only available with flash attn v2"
+                )
+
+            # Flash attention v1 requires q, k and v to have the same number of heads
+            if k.shape[1] != q.shape[1]:
+                # MQA expand
+                if k.shape[1] == 1:
+                    k = k.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = k.shape
+                    k = (
+                        k.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+            if v.shape[1] != q.shape[1]:
+                # MQA expand
+                if v.shape[1] == 1:
+                    v = v.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = v.shape
+                    v = (
+                        v.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+
+            out = torch.empty_like(q)
+            flash_attn_cuda.fwd(
+                q,
+                k,
+                v,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_q,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                0,
+                None,
+            )
+            return out
+
+else:
+    raise RuntimeError(f"Unknwon attention {ATTENTION}")
+
 # Prefill in the cache with every kind of attention, unless we
 # have a configuration that requires flash-attention v1, which
 # does not support block tables.
-PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
+PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)

 __all__ = [
     "PREFILL_IN_KV_CACHE",


@@ -1602,6 +1602,8 @@ class FlashCausalLM(Model):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices

+        print(slots)
+
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
@@ -1677,9 +1679,9 @@ class FlashCausalLM(Model):
             cuda_graph["input_lengths"].zero_()
             cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
             cuda_graph["cache_lengths"].zero_()
-            cuda_graph["cache_lengths"][
-                : cache_lengths_tensor.shape[0]
-            ] = cache_lengths_tensor
+            cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
+                cache_lengths_tensor
+            )

             with self._forward_context(
                 block_tables=cuda_graph["block_tables"],


@@ -5,9 +5,13 @@ from typing import Dict, Optional
 from text_generation_server.utils.log import log_master

-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
-log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 ATTENTION = os.getenv("ATTENTION", "flashinfer")
+default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
+PREFIX_CACHING = os.getenv("PREFIX_CACHING", default_prefix_caching).lower() in {
+    "1",
+    "true",
+}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
     ATTENTION in _expected
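In globals.py the switch is renamed from USE_PREFIX_CACHING to PREFIX_CACHING, and its default now follows the attention backend: on for flashinfer and flashdecoding, off for paged. A standalone sketch of the resulting behaviour (it mirrors the diff rather than importing the module):

def prefix_caching_enabled(env: dict) -> bool:
    attention = env.get("ATTENTION", "flashinfer")
    default = "1" if attention in {"flashinfer", "flashdecoding"} else "0"
    return env.get("PREFIX_CACHING", default).lower() in {"1", "true"}

assert prefix_caching_enabled({}) is True                                            # flashinfer default
assert prefix_caching_enabled({"ATTENTION": "paged"}) is False                       # off unless opted in
assert prefix_caching_enabled({"ATTENTION": "paged", "PREFIX_CACHING": "1"}) is True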