mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
Using flash decoding
Conditional flashdecoding. Fix max_q. Working kvcache Working version with flash decoding. Make it work for mistral.
This commit is contained in:
parent
cff472ba2b
commit
1b86d0f31d
@ -1,11 +1,11 @@
|
||||
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
|
||||
flash_att_v2_commit_cuda := v2.5.8
|
||||
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
|
||||
|
||||
|
||||
flash-attention-v2-cuda:
|
||||
# Clone flash attention
|
||||
pip install -U packaging ninja --no-cache-dir
|
||||
git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2
|
||||
|
||||
build-flash-attention-v2-cuda: flash-attention-v2-cuda
|
||||
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda)
|
||||
|
@ -3,8 +3,9 @@ import torch
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
|
||||
BLOCK_SIZE: int = 16
|
||||
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
|
||||
# Will be set in warmup
|
||||
CACHE_MANAGER: Optional["CacheManager"] = None
|
||||
|
||||
@ -30,21 +31,38 @@ class CacheManager:
|
||||
else:
|
||||
x = self.block_size // element_size
|
||||
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size // x, self.block_size, x),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size, self.block_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
if FLASH_DECODING:
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, self.block_size, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, self.block_size, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
else:
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size // x, self.block_size, x),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size, self.block_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
|
||||
self.slots = torch.arange(
|
||||
0, num_blocks * self.block_size, dtype=torch.int64
|
||||
|
@ -28,6 +28,7 @@ from transformers.activations import ACT2FN
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
@ -151,38 +152,75 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
if FLASH_DECODING:
|
||||
# Prefill
|
||||
kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
|
||||
:, 0
|
||||
]
|
||||
kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
|
||||
:, 1
|
||||
]
|
||||
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
# torch.select(kv, dim=1, index=0),
|
||||
# torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
block_tables,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
None,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
@ -27,6 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
@ -214,44 +215,79 @@ class MistralAttention(torch.nn.Module):
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
# Decode
|
||||
if FLASH_DECODING:
|
||||
# Prefill
|
||||
kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
|
||||
:, 0
|
||||
]
|
||||
kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
|
||||
:, 1
|
||||
]
|
||||
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
# torch.select(kv, dim=1, index=0),
|
||||
# torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
block_tables,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
None,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
@ -1,9 +1,13 @@
|
||||
import torch
|
||||
import os
|
||||
from loguru import logger
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
|
||||
if FLASH_DECODING:
|
||||
logger.info("Using FLASH_DECODING")
|
||||
if cuda_graphs is not None:
|
||||
try:
|
||||
cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
|
||||
|
@ -134,6 +134,7 @@ elif HAS_FLASH_ATTN_V2_CUDA:
|
||||
v,
|
||||
out,
|
||||
cu_seqlens,
|
||||
block_tables,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
@ -149,7 +150,7 @@ elif HAS_FLASH_ATTN_V2_CUDA:
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
None,
|
||||
None,
|
||||
block_tables,
|
||||
None,
|
||||
max_s,
|
||||
max_s,
|
||||
|
@ -85,53 +85,93 @@ def attention(
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||
if use_v1:
|
||||
ops.paged_attention_v1(
|
||||
out,
|
||||
if FLASH_DECODING:
|
||||
cu_seqlen_q = torch.arange(
|
||||
input_lengths.shape[0] + 1, device=query.device, dtype=torch.int32
|
||||
)
|
||||
cu_seqlen_k = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(1,), device=input_lengths.device, dtype=input_lengths.dtype
|
||||
),
|
||||
input_lengths.cumsum(dim=-1),
|
||||
]
|
||||
).to(dtype=torch.int32)
|
||||
max_q = 1
|
||||
max_k = max_s
|
||||
import flash_attn_2_cuda
|
||||
|
||||
flash_attn_2_cuda.varlen_fwd(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
out,
|
||||
cu_seqlen_q,
|
||||
cu_seqlen_k,
|
||||
None,
|
||||
block_tables,
|
||||
None,
|
||||
max_q,
|
||||
max_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
-1,
|
||||
0,
|
||||
False,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=out.dtype,
|
||||
device=out.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=out.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
from vllm._C import ops
|
||||
|
||||
ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||
if use_v1:
|
||||
ops.paged_attention_v1(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=out.dtype,
|
||||
device=out.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=out.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user