Mirror of https://github.com/huggingface/text-generation-inference.git
Fixing non blocked attentions

commit 4fa4da3cb6
parent 5e70158b2c
@@ -5,68 +5,50 @@ import torch
 from typing import Optional
 
 
-if ATTENTION in {"flashinfer", "flashdecoding"}:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        cache_lengths: torch.Tensor
-        cu_seqlen_q: Optional[torch.Tensor]
-        cu_seqlen_k: Optional[torch.Tensor]
-        max_q: int
-        max_k: int
-
-        def __init__(
-            self,
-            input_lengths,
-            cache_lengths,
-            cu_seqlen_q=None,
-            max_q=None,
-            max_k=None,
-        ):
-            self.input_lengths = input_lengths
-            self.cache_lengths = cache_lengths
-            device = self.input_lengths.device
-            shape = self.input_lengths.shape
-            if cu_seqlen_q is None:
-                cu_seqlen_q = torch.arange(
-                    shape[0] + 1,
-                    device=device,
-                    dtype=torch.int32,
-                )
-                max_q = 1
-            else:
-                assert max_q is not None
-                assert max_k is not None
-            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
-
-            # cuda graphs don't like this and this is necessary to clamp within mistral
-            # Although FA2 might not want the clamping
-            # cu_seqlen_k[0] = 0
-            total = self.input_lengths + self.cache_lengths
-            torch.cumsum(total, -1, out=cu_seqlen_k[1:])
-
-            self.cu_seqlen_q = cu_seqlen_q
-            self.cu_seqlen_k = cu_seqlen_k
-            self.max_q = max_q
-            self.max_k = max_k
-
-        def clamp(self, max):
-            # Flash decoding doesn't need to clamp
-            return self
-
-else:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        cache_lengths: torch.Tensor
-        cu_seqlen_q: torch.Tensor
-        max_q: int
-        max_k: int
-
-        def clamp(self, max):
-            if SYSTEM == "rocm":
-                return self
-            raise NotImplementedError("Not implemented seqlen for paged")
-            return Seqlen(torch.clamp(self.input_lengths, max=max))
+@dataclass
+class Seqlen:
+    input_lengths: torch.Tensor
+    cache_lengths: torch.Tensor
+    cu_seqlen_q: Optional[torch.Tensor]
+    cu_seqlen_k: Optional[torch.Tensor]
+    max_q: int
+    max_k: int
+
+    def __init__(
+        self,
+        input_lengths,
+        cache_lengths,
+        cu_seqlen_q=None,
+        max_q=None,
+        max_k=None,
+    ):
+        self.input_lengths = input_lengths
+        self.cache_lengths = cache_lengths
+        device = self.input_lengths.device
+        shape = self.input_lengths.shape
+        if cu_seqlen_q is None:
+            cu_seqlen_q = torch.arange(
+                shape[0] + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+            max_q = 1
+        else:
+            assert max_q is not None
+            assert max_k is not None
+        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+
+        # cuda graphs don't like this and this is necessary to clamp within mistral
+        # Although FA2 might not want the clamping
+        # cu_seqlen_k[0] = 0
+        total = self.input_lengths + self.cache_lengths
+        torch.cumsum(total, -1, out=cu_seqlen_k[1:])
+
+        self.cu_seqlen_q = cu_seqlen_q
+        self.cu_seqlen_k = cu_seqlen_k
+        self.max_q = max_q
+        self.max_k = max_k
+
+    def clamp(self, max):
+        # Flash decoding doesn't need to clamp
+        return self
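For reference, here is a small sketch (my own example, not part of the commit; toy CPU values) of the bookkeeping the unified Seqlen above performs: when no cu_seqlen_q is given it defaults to one query token per sequence, and cu_seqlen_k is the prefix sum of cached plus newly added tokens per sequence.

    import torch

    # Toy batch of 3 sequences: tokens already in the KV cache and tokens
    # added in this forward pass (illustrative numbers only).
    cache_lengths = torch.tensor([5, 0, 7], dtype=torch.int32)
    input_lengths = torch.tensor([2, 1, 3], dtype=torch.int32)

    # Decode case: cu_seqlen_q is not provided, so it becomes arange(batch + 1),
    # i.e. exactly one query token per sequence.
    cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

    # cu_seqlen_k accumulates cached + new tokens, as Seqlen.__init__ does above
    # (the class writes the cumsum in place via the out= argument).
    total = input_lengths + cache_lengths
    cu_seqlen_k = torch.zeros(total.shape[-1] + 1, dtype=torch.int32)
    cu_seqlen_k[1:] = torch.cumsum(total, dim=-1)

    print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]
    print(cu_seqlen_k.tolist())  # [0, 7, 8, 18]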
@@ -123,7 +123,7 @@ def paged_attention(
     else:
         if softcap is not None:
             raise RuntimeError("Paged attention doesn't support softcapping")
-        input_lengths = seqlen.input_lengths
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
         from vllm._C import ops
 
         out = torch.empty_like(query)
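One way to read this one-line change (my interpretation, not stated in the commit): with prefix caching, the paged-attention kernel has to attend over every token already present in the KV cache plus the tokens appended in this step, so the per-sequence length handed to the kernel is the sum of both counters.

    # Illustrative only (toy values): the kernel sees the full KV length.
    cache_lengths = [5, 0, 7]   # tokens already present in the KV cache
    input_lengths = [2, 1, 3]   # tokens appended by this forward pass
    kv_lengths = [c + i for c, i in zip(cache_lengths, input_lengths)]  # [7, 1, 10]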
@@ -243,118 +243,230 @@ if ATTENTION == "flashinfer":
         sm_scale=softmax_scale,
         window_left=window_size_left,
     )
 
-elif V2:
-
-    def attention(
-        q,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        seqlen: Seqlen,
-        block_tables: torch.Tensor,
-        softmax_scale,
-        window_size_left=-1,
-        causal=True,
-        softcap=0.0,
-    ):
-        out = torch.empty_like(q)
-        if window_size_left <= 0 and window_size_left != -1:
-            raise ValueError("`window_size_left` must be > 0 or -1")
-        return flash_attn_2_cuda.varlen_fwd(
-            q,
-            key_cache,
-            value_cache,
-            out,
-            seqlen.cu_seqlen_q,
-            seqlen.cu_seqlen_k,
-            None,
-            None,
-            block_tables,
-            None,
-            seqlen.max_q,
-            seqlen.max_k,
-            0.0,
-            softmax_scale,
-            False,
-            causal,
-            window_size_left,
-            0,
-            softcap,
-            False,
-            None,
-        )[0]
-
-else:
-
-    def attention(
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        seqlen: Seqlen,
-        block_tables: torch.Tensor,
-        softmax_scale: float,
-        window_size_left: int = -1,
-        causal: bool = True,
-        softcap=None,
-    ):
-        if window_size_left != -1:
-            raise NotImplementedError(
-                "window_size_left is only available with flash attn v2"
-            )
-        if softcap is not None:
-            raise NotImplementedError("softcap is only available with flash attn v2")
-
-        # Flash attention v1 requires q, k and v to have the same number of heads
-        if k.shape[1] != q.shape[1]:
-            # MQA expand
-            if k.shape[1] == 1:
-                k = k.expand(-1, q.shape[1], -1)
-            # Grouped attention reshape
-            else:
-                original_shape = k.shape
-                k = (
-                    k.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
-                    .reshape(original_shape[0], -1, original_shape[2])
-                )
-        if v.shape[1] != q.shape[1]:
-            # MQA expand
-            if v.shape[1] == 1:
-                v = v.expand(-1, q.shape[1], -1)
-            # Grouped attention reshape
-            else:
-                original_shape = v.shape
-                v = (
-                    v.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
-                    .reshape(original_shape[0], -1, original_shape[2])
-                )
-
-        out = torch.empty_like(q)
-        flash_attn_cuda.fwd(
-            q,
-            k,
-            v,
-            out,
-            seqlen.cu_seqlen_q,
-            seqlen.cu_seqlen_q,
-            seqlen.max_q,
-            seqlen.max_k,
-            0.0,
-            softmax_scale,
-            False,
-            causal,
-            False,
-            0,
-            None,
-        )
-        return out
+elif ATTENTION == "flashdecoding":
+    if V2:
+
+        def attention(
+            q,
+            key_cache: torch.Tensor,
+            value_cache: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale,
+            window_size_left=-1,
+            causal=True,
+            softcap=0.0,
+        ):
+            out = torch.empty_like(q)
+            if window_size_left <= 0 and window_size_left != -1:
+                raise ValueError("`window_size_left` must be > 0 or -1")
+            return flash_attn_2_cuda.varlen_fwd(
+                q,
+                key_cache,
+                value_cache,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_k,
+                None,
+                None,
+                block_tables,
+                None,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                window_size_left,
+                0,
+                softcap,
+                False,
+                None,
+            )[0]
+
+    else:
+
+        def attention(
+            q: torch.Tensor,
+            k: torch.Tensor,
+            v: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale: float,
+            window_size_left: int = -1,
+            causal: bool = True,
+            softcap=None,
+        ):
+            if window_size_left != -1:
+                raise NotImplementedError(
+                    "window_size_left is only available with flash attn v2"
+                )
+            if softcap is not None:
+                raise NotImplementedError(
+                    "softcap is only available with flash attn v2"
+                )
+
+            # Flash attention v1 requires q, k and v to have the same number of heads
+            if k.shape[1] != q.shape[1]:
+                # MQA expand
+                if k.shape[1] == 1:
+                    k = k.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = k.shape
+                    k = (
+                        k.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+            if v.shape[1] != q.shape[1]:
+                # MQA expand
+                if v.shape[1] == 1:
+                    v = v.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = v.shape
+                    v = (
+                        v.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+
+            out = torch.empty_like(q)
+            flash_attn_cuda.fwd(
+                q,
+                k,
+                v,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_q,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                0,
+                None,
+            )
+            return out
+
+elif ATTENTION == "paged":
+    if V2:
+
+        def attention(
+            q,
+            key_cache: torch.Tensor,
+            value_cache: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale,
+            window_size_left=-1,
+            causal=True,
+            softcap=0.0,
+        ):
+            out = torch.empty_like(q)
+            if window_size_left <= 0 and window_size_left != -1:
+                raise ValueError("`window_size_left` must be > 0 or -1")
+            return flash_attn_2_cuda.varlen_fwd(
+                q,
+                key_cache,
+                value_cache,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_k,
+                None,
+                None,
+                None,  # block_tables,
+                None,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                window_size_left,
+                0,
+                softcap,
+                False,
+                None,
+            )[0]
+
+    else:
+
+        def attention(
+            q: torch.Tensor,
+            k: torch.Tensor,
+            v: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale: float,
+            window_size_left: int = -1,
+            causal: bool = True,
+            softcap=None,
+        ):
+            if window_size_left != -1:
+                raise NotImplementedError(
+                    "window_size_left is only available with flash attn v2"
+                )
+            if softcap is not None:
+                raise NotImplementedError(
+                    "softcap is only available with flash attn v2"
+                )
+
+            # Flash attention v1 requires q, k and v to have the same number of heads
+            if k.shape[1] != q.shape[1]:
+                # MQA expand
+                if k.shape[1] == 1:
+                    k = k.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = k.shape
+                    k = (
+                        k.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+            if v.shape[1] != q.shape[1]:
+                # MQA expand
+                if v.shape[1] == 1:
+                    v = v.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = v.shape
+                    v = (
+                        v.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+
+            out = torch.empty_like(q)
+            flash_attn_cuda.fwd(
+                q,
+                k,
+                v,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_q,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                0,
+                None,
+            )
+            return out
+
+else:
+    raise RuntimeError(f"Unknwon attention {ATTENTION}")
 
 
 # Prefill in the cache with every kind of attention, unless we
 # have a configuration that requires flash-attention v1, which
 # does not support block tables.
-PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
+PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
 
 __all__ = [
     "PREFILL_IN_KV_CACHE",
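To make the behavioral change in PREFILL_IN_KV_CACHE easier to see, here is a small comparison sketch (my own code, not part of the diff) that enumerates the three backends:

    # Old vs. new flag, per attention backend and flash-attn v2 availability.
    def prefill_in_kv_cache_old(attention: str, v2: bool) -> bool:
        return attention != "paged" or v2

    def prefill_in_kv_cache_new(attention: str, v2: bool) -> bool:
        return attention == "flashinfer" or (attention == "flashdecoding" and v2)

    for attention in ("flashinfer", "flashdecoding", "paged"):
        for v2 in (True, False):
            old = prefill_in_kv_cache_old(attention, v2)
            new = prefill_in_kv_cache_new(attention, v2)
            print(f"{attention:14} v2={v2!s:5} old={old!s:5} new={new!s}")
    # Only two combinations change: "flashdecoding" without v2 and "paged" with v2
    # both go from True to False.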
@@ -1602,6 +1602,8 @@ class FlashCausalLM(Model):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
+        print(slots)
+
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
@@ -1677,9 +1679,9 @@ class FlashCausalLM(Model):
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
         cuda_graph["cache_lengths"].zero_()
-        cuda_graph["cache_lengths"][
-            : cache_lengths_tensor.shape[0]
-        ] = cache_lengths_tensor
+        cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
+            cache_lengths_tensor
+        )
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
@@ -5,9 +5,13 @@ from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
 
-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
-log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 ATTENTION = os.getenv("ATTENTION", "flashinfer")
+default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
+PREFIX_CACHING = os.getenv("PREFIX_CACHING", default_prefix_caching).lower() in {
+    "1",
+    "true",
+}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
     ATTENTION in _expected
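Note that this hunk also renames the override variable from USE_PREFIX_CACHING to PREFIX_CACHING, and prefix caching now defaults on for flashinfer and flashdecoding but off for paged. For illustration (my own sketch, not part of the commit), the default resolution behaves like this:

    import os

    # Simulate a deployment that selects the non-blocked "paged" backend and
    # does not set PREFIX_CACHING explicitly.
    os.environ["ATTENTION"] = "paged"
    os.environ.pop("PREFIX_CACHING", None)

    attention = os.getenv("ATTENTION", "flashinfer")
    default_prefix_caching = "1" if attention in {"flashinfer", "flashdecoding"} else "0"
    prefix_caching = os.getenv("PREFIX_CACHING", default_prefix_caching).lower() in {"1", "true"}
    print(prefix_caching)  # False: "paged" opts out of prefix caching unless overridden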