mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fixing non blocked attentions
This commit is contained in:
parent
5e70158b2c
commit
4fa4da3cb6
@ -5,8 +5,6 @@ import torch
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
if ATTENTION in {"flashinfer", "flashdecoding"}:
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Seqlen:
|
class Seqlen:
|
||||||
input_lengths: torch.Tensor
|
input_lengths: torch.Tensor
|
||||||
@ -54,19 +52,3 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
|||||||
def clamp(self, max):
|
def clamp(self, max):
|
||||||
# Flash decoding doesn't need to clamp
|
# Flash decoding doesn't need to clamp
|
||||||
return self
|
return self
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Seqlen:
|
|
||||||
input_lengths: torch.Tensor
|
|
||||||
cache_lengths: torch.Tensor
|
|
||||||
cu_seqlen_q: torch.Tensor
|
|
||||||
max_q: int
|
|
||||||
max_k: int
|
|
||||||
|
|
||||||
def clamp(self, max):
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
return self
|
|
||||||
raise NotImplementedError("Not implemented seqlen for paged")
|
|
||||||
return Seqlen(torch.clamp(self.input_lengths, max=max))
|
|
||||||
|
@ -123,7 +123,7 @@ def paged_attention(
|
|||||||
else:
|
else:
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||||
input_lengths = seqlen.input_lengths
|
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||||
from vllm._C import ops
|
from vllm._C import ops
|
||||||
|
|
||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
@ -243,8 +243,8 @@ if ATTENTION == "flashinfer":
|
|||||||
sm_scale=softmax_scale,
|
sm_scale=softmax_scale,
|
||||||
window_left=window_size_left,
|
window_left=window_size_left,
|
||||||
)
|
)
|
||||||
|
elif ATTENTION == "flashdecoding":
|
||||||
elif V2:
|
if V2:
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q,
|
q,
|
||||||
@ -302,7 +302,9 @@ else:
|
|||||||
"window_size_left is only available with flash attn v2"
|
"window_size_left is only available with flash attn v2"
|
||||||
)
|
)
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
raise NotImplementedError("softcap is only available with flash attn v2")
|
raise NotImplementedError(
|
||||||
|
"softcap is only available with flash attn v2"
|
||||||
|
)
|
||||||
|
|
||||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
if k.shape[1] != q.shape[1]:
|
if k.shape[1] != q.shape[1]:
|
||||||
@ -349,12 +351,122 @@ else:
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
elif ATTENTION == "paged":
|
||||||
|
if V2:
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
causal=True,
|
||||||
|
softcap=0.0,
|
||||||
|
):
|
||||||
|
out = torch.empty_like(q)
|
||||||
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
|
q,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
out,
|
||||||
|
seqlen.cu_seqlen_q,
|
||||||
|
seqlen.cu_seqlen_k,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None, # block_tables,
|
||||||
|
None,
|
||||||
|
seqlen.max_q,
|
||||||
|
seqlen.max_k,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
causal,
|
||||||
|
window_size_left,
|
||||||
|
0,
|
||||||
|
softcap,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
window_size_left: int = -1,
|
||||||
|
causal: bool = True,
|
||||||
|
softcap=None,
|
||||||
|
):
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"window_size_left is only available with flash attn v2"
|
||||||
|
)
|
||||||
|
if softcap is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"softcap is only available with flash attn v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
|
if k.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if k.shape[1] == 1:
|
||||||
|
k = k.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = k.shape
|
||||||
|
k = (
|
||||||
|
k.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
if v.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if v.shape[1] == 1:
|
||||||
|
v = v.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = v.shape
|
||||||
|
v = (
|
||||||
|
v.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
out = torch.empty_like(q)
|
||||||
|
flash_attn_cuda.fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
seqlen.cu_seqlen_q,
|
||||||
|
seqlen.cu_seqlen_q,
|
||||||
|
seqlen.max_q,
|
||||||
|
seqlen.max_k,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
causal,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unknwon attention {ATTENTION}")
|
||||||
|
|
||||||
|
|
||||||
# Prefill in the cache with every kind of attention, unless we
|
# Prefill in the cache with every kind of attention, unless we
|
||||||
# have a configuration that requires flash-attention v1, which
|
# have a configuration that requires flash-attention v1, which
|
||||||
# does not support block tables.
|
# does not support block tables.
|
||||||
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
|
PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PREFILL_IN_KV_CACHE",
|
"PREFILL_IN_KV_CACHE",
|
||||||
|
@ -1602,6 +1602,8 @@ class FlashCausalLM(Model):
|
|||||||
max_s = batch.max_current_length
|
max_s = batch.max_current_length
|
||||||
lm_head_indices = batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
|
print(slots)
|
||||||
|
|
||||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||||
# in a circular buffer mode.
|
# in a circular buffer mode.
|
||||||
@ -1677,9 +1679,9 @@ class FlashCausalLM(Model):
|
|||||||
cuda_graph["input_lengths"].zero_()
|
cuda_graph["input_lengths"].zero_()
|
||||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||||
cuda_graph["cache_lengths"].zero_()
|
cuda_graph["cache_lengths"].zero_()
|
||||||
cuda_graph["cache_lengths"][
|
cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
|
||||||
: cache_lengths_tensor.shape[0]
|
cache_lengths_tensor
|
||||||
] = cache_lengths_tensor
|
)
|
||||||
|
|
||||||
with self._forward_context(
|
with self._forward_context(
|
||||||
block_tables=cuda_graph["block_tables"],
|
block_tables=cuda_graph["block_tables"],
|
||||||
|
@ -5,9 +5,13 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
from text_generation_server.utils.log import log_master
|
from text_generation_server.utils.log import log_master
|
||||||
|
|
||||||
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
|
|
||||||
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
|
||||||
ATTENTION = os.getenv("ATTENTION", "flashinfer")
|
ATTENTION = os.getenv("ATTENTION", "flashinfer")
|
||||||
|
default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
|
||||||
|
PREFIX_CACHING = os.getenv("PREFIX_CACHING", default_prefix_caching).lower() in {
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
}
|
||||||
|
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
||||||
_expected = {"paged", "flashdecoding", "flashinfer"}
|
_expected = {"paged", "flashdecoding", "flashinfer"}
|
||||||
assert (
|
assert (
|
||||||
ATTENTION in _expected
|
ATTENTION in _expected
|
||||||
|
Loading…
Reference in New Issue
Block a user