Using flash decoding

Conditional flash decoding.

Fix max_q.

Working KV cache.

Working version with flash decoding.

Make it work for mistral.
Nicolas Patry 2024-05-17 08:43:33 +00:00
parent cff472ba2b
commit 1b86d0f31d
7 changed files with 260 additions and 123 deletions

View File

@@ -1,11 +1,11 @@
-flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
+flash_att_v2_commit_cuda := v2.5.8
 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6


 flash-attention-v2-cuda:
 	# Clone flash attention
 	pip install -U packaging ninja --no-cache-dir
-	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
+	git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2

 build-flash-attention-v2-cuda: flash-attention-v2-cuda
 	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda)

View File

@@ -3,8 +3,9 @@ import torch
 from typing import Optional, List, Tuple

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import FLASH_DECODING

-BLOCK_SIZE: int = 16
+BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
@@ -30,21 +31,38 @@ class CacheManager:
         else:
             x = self.block_size // element_size
-        self.kv_cache = [
-            (
-                torch.empty(
-                    (num_blocks, num_heads, head_size // x, self.block_size, x),
-                    dtype=dtype,
-                    device=device,
-                ),
-                torch.empty(
-                    (num_blocks, num_heads, head_size, self.block_size),
-                    dtype=dtype,
-                    device=device,
-                ),
-            )
-            for _ in range(num_layers)
-        ]
+        if FLASH_DECODING:
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, self.block_size, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, self.block_size, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]
+        else:
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, num_heads, head_size // x, self.block_size, x),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, num_heads, head_size, self.block_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]

         self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
         self.slots = torch.arange(
             0, num_blocks * self.block_size, dtype=torch.int64
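Note: a minimal sketch (plain PyTorch, made-up sizes, not TGI code) of why the flash-decoding layout above is addressed through a flat view: with the cache shaped (num_blocks, block_size, num_heads, head_size), flattening the first two dimensions turns a slot number directly into a row index, which is what the modeling code below relies on when it writes kv_cache[0].view(-1, num_heads, head_size)[slots] = ...

    import torch

    # Hypothetical sizes, only for illustration.
    num_blocks, block_size, num_heads, head_size = 4, 256, 8, 64

    # Flash-decoding layout: (num_blocks, block_size, num_heads, head_size).
    key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)

    # A flat slot index picks a (block, offset) pair once the first two
    # dimensions are flattened.
    slot = 517
    flat = key_cache.view(-1, num_heads, head_size)
    flat[slot] = torch.randn(num_heads, head_size)

    block, offset = slot // block_size, slot % block_size
    assert torch.equal(key_cache[block, offset], flat[slot])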

View File

@@ -28,6 +28,7 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -151,38 +152,75 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
-
         # output tensor
         attn_output = torch.empty_like(query)

-        # Prefill
-        if cu_seqlen_prefill is not None:
-            # flash attention
-            flash_attn.attention(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
-                attn_output,
-                cu_seqlen_prefill,
-                max_s,
-                self.softmax_scale,
-            )
-        # Decode
-        else:
-            paged_attention.attention(
-                attn_output,
-                query,
-                kv_cache[0],
-                kv_cache[1],
-                self.kv_head_mapping,
-                self.softmax_scale,
-                block_tables,
-                input_lengths,
-                max_s,
-            )
+        if FLASH_DECODING:
+            # Prefill
+            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
+                :, 0
+            ]
+            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
+                :, 1
+            ]
+
+            if cu_seqlen_prefill is not None:
+                # flash attention
+                flash_attn.attention(
+                    query,
+                    # torch.select(kv, dim=1, index=0),
+                    # torch.select(kv, dim=1, index=1),
+                    kv_cache[0],
+                    kv_cache[1],
+                    attn_output,
+                    cu_seqlen_prefill,
+                    block_tables,
+                    max_s,
+                    self.softmax_scale,
+                )
+            # Decode
+            else:
+                paged_attention.attention(
+                    attn_output,
+                    query,
+                    kv_cache[0],
+                    kv_cache[1],
+                    self.kv_head_mapping,
+                    self.softmax_scale,
+                    block_tables,
+                    input_lengths,
+                    max_s,
+                )
+        else:
+            paged_attention.reshape_and_cache(
+                kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+            )
+
+            # Prefill
+            if cu_seqlen_prefill is not None:
+                # flash attention
+                flash_attn.attention(
+                    query,
+                    torch.select(kv, dim=1, index=0),
+                    torch.select(kv, dim=1, index=1),
+                    attn_output,
+                    cu_seqlen_prefill,
+                    None,
+                    max_s,
+                    self.softmax_scale,
+                )
+            # Decode
+            else:
+                paged_attention.attention(
+                    attn_output,
+                    query,
+                    kv_cache[0],
+                    kv_cache[1],
+                    self.kv_head_mapping,
+                    self.softmax_scale,
+                    block_tables,
+                    input_lengths,
+                    max_s,
+                )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
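A small standalone sketch (toy sizes, CPU only; only the tensor shapes mirror the code above) of why the prefill branch can now hand kv_cache[0]/kv_cache[1] plus block_tables to flash attention instead of the freshly computed kv: scattering through the flat slots and gathering back through the block table reproduces exactly the keys the old path would have passed.

    import torch

    # Toy sizes: one sequence of 5 tokens, block_size 4, blocks 1 and 2 assigned.
    block_size, num_heads, head_size, num_blocks = 4, 2, 3, 3
    k = torch.randn(5, num_heads, head_size)  # freshly computed keys (old path input)
    key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)

    block_table = torch.tensor([1, 2])
    slots = torch.tensor([4, 5, 6, 7, 8])  # flat slots covered by those blocks

    # New path: scatter into the flash-decoding layout via flat slots ...
    key_cache.view(-1, num_heads, head_size)[slots] = k

    # ... and the kernel can recover the same keys through the block table.
    gathered = key_cache[block_table].reshape(-1, num_heads, head_size)[: len(slots)]
    assert torch.allclose(gathered, k)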

View File

@@ -27,6 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -214,44 +215,79 @@ class MistralAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
-
-        paged_attention.reshape_and_cache(
-            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
-        )
-
-        # output tensor
         attn_output = torch.empty_like(query)

-        # Prefill
-        if cu_seqlen_prefill is not None:
-            # flash attention
-            flash_attn.attention(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
-                attn_output,
-                cu_seqlen_prefill,
-                max_s,
-                self.softmax_scale,
-                window_size_left=self.max_past,
-            )
-        # Decode
-        else:
-            paged_attention.attention(
-                attn_output,
-                query,
-                kv_cache[0],
-                kv_cache[1],
-                self.kv_head_mapping,
-                self.softmax_scale,
-                block_tables,
-                input_lengths,
-                max_s,
-            )
+        if FLASH_DECODING:
+            # Prefill
+            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
+                :, 0
+            ]
+            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
+                :, 1
+            ]
+
+            if cu_seqlen_prefill is not None:
+                # flash attention
+                flash_attn.attention(
+                    query,
+                    # torch.select(kv, dim=1, index=0),
+                    # torch.select(kv, dim=1, index=1),
+                    kv_cache[0],
+                    kv_cache[1],
+                    attn_output,
+                    cu_seqlen_prefill,
+                    block_tables,
+                    max_s,
+                    self.softmax_scale,
+                )
+            # Decode
+            else:
+                paged_attention.attention(
+                    attn_output,
+                    query,
+                    kv_cache[0],
+                    kv_cache[1],
+                    self.kv_head_mapping,
+                    self.softmax_scale,
+                    block_tables,
+                    input_lengths,
+                    max_s,
+                )
+        else:
+            if prefill_cache_indices is not None:
+                kv_to_cache = kv[prefill_cache_indices]
+            else:
+                kv_to_cache = kv
+
+            paged_attention.reshape_and_cache(
+                kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
+            )
+
+            # Prefill
+            if cu_seqlen_prefill is not None:
+                # flash attention
+                flash_attn.attention(
+                    query,
+                    torch.select(kv, dim=1, index=0),
+                    torch.select(kv, dim=1, index=1),
+                    attn_output,
+                    cu_seqlen_prefill,
+                    None,
+                    max_s,
+                    self.softmax_scale,
+                )
+            # Decode
+            else:
+                paged_attention.attention(
+                    attn_output,
+                    query,
+                    kv_cache[0],
+                    kv_cache[1],
+                    self.kv_head_mapping,
+                    self.softmax_scale,
+                    block_tables,
+                    input_lengths,
+                    max_s,
+                )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@@ -1,9 +1,13 @@
 import torch
 import os
+from loguru import logger

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
+FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
+if FLASH_DECODING:
+    logger.info("Using FLASH_DECODING")
 if cuda_graphs is not None:
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
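Since FLASH_DECODING is evaluated once at import time, the environment variable has to be set before this module is imported; a minimal sketch of enabling it (assumes the text_generation_server package is importable):

    import os

    # Must be set before text_generation_server.models.globals is imported,
    # because FLASH_DECODING is read from the environment at import time.
    os.environ["FLASH_DECODING"] = "1"

    from text_generation_server.models.globals import FLASH_DECODING

    assert FLASH_DECODING  # "1", "true" and "True" all enable the path

As a side effect, the cache manager's BLOCK_SIZE switches from 16 to 256 (see the diff above).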

View File

@@ -134,6 +134,7 @@ elif HAS_FLASH_ATTN_V2_CUDA:
         v,
         out,
         cu_seqlens,
+        block_tables,
         max_s,
         softmax_scale,
         window_size_left=-1,
@@ -149,7 +150,7 @@ elif HAS_FLASH_ATTN_V2_CUDA:
             cu_seqlens,
             cu_seqlens,
             None,
-            None,
+            block_tables,
             None,
             max_s,
             max_s,

View File

@@ -85,53 +85,93 @@ def attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
-    if use_v1:
-        ops.paged_attention_v1(
-            out,
-            query,
-            key_cache,
-            value_cache,
-            kv_head_mapping,
-            softmax_scale,
-            block_tables,
-            input_lengths,
-            block_size,
-            max_s,
-            None,
-            "auto",
-            1.0,
-        )
-    else:
-        # Run PagedAttention V2.
-        assert _PARTITION_SIZE % block_size == 0
-        tmp_output = torch.empty(
-            size=(num_seqs, num_heads, max_num_partitions, head_size),
-            dtype=out.dtype,
-            device=out.device,
-        )
-        exp_sums = torch.empty(
-            size=(num_seqs, num_heads, max_num_partitions),
-            dtype=torch.float32,
-            device=out.device,
-        )
-        max_logits = torch.empty_like(exp_sums)
-
-        ops.paged_attention_v2(
-            out,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            kv_head_mapping,
-            softmax_scale,
-            block_tables,
-            input_lengths,
-            block_size,
-            max_s,
-            None,
-            "auto",
-            1.0,
-        )
+    if FLASH_DECODING:
+        cu_seqlen_q = torch.arange(
+            input_lengths.shape[0] + 1, device=query.device, dtype=torch.int32
+        )
+        cu_seqlen_k = torch.cat(
+            [
+                torch.zeros(
+                    (1,), device=input_lengths.device, dtype=input_lengths.dtype
+                ),
+                input_lengths.cumsum(dim=-1),
+            ]
+        ).to(dtype=torch.int32)
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda
+
+        flash_attn_2_cuda.varlen_fwd(
+            query,
+            key_cache,
+            value_cache,
+            out,
+            cu_seqlen_q,
+            cu_seqlen_k,
+            None,
+            block_tables,
+            None,
+            max_q,
+            max_k,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            -1,
+            0,
+            False,
+            None,
+        )
+    else:
+        from vllm._C import ops
+
+        use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
+        if use_v1:
+            ops.paged_attention_v1(
+                out,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+                1.0,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=out.dtype,
+                device=out.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=out.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+
+            ops.paged_attention_v2(
+                out,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+                1.0,
+            )
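To make the flash-decoding branch concrete, a standalone sketch (hypothetical batch, CPU only, no kernel call) of the sequence-length metadata handed to flash_attn_2_cuda.varlen_fwd during decode: each sequence contributes exactly one query token, so max_q is 1 and cu_seqlen_q is simply 0..batch_size, while cu_seqlen_k is the zero-prefixed cumulative sum of input_lengths.

    import torch

    # Hypothetical batch of 3 sequences at these cached lengths (decode step).
    input_lengths = torch.tensor([5, 17, 2], dtype=torch.int32)

    # One query token per sequence -> cu_seqlen_q is just 0..batch_size.
    cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

    # Keys cover the whole cached context -> zero-prefixed cumulative sum.
    cu_seqlen_k = torch.cat(
        [torch.zeros((1,), dtype=input_lengths.dtype), input_lengths.cumsum(dim=-1)]
    ).to(dtype=torch.int32)

    max_q, max_k = 1, int(input_lengths.max())

    assert cu_seqlen_q.tolist() == [0, 1, 2, 3]
    assert cu_seqlen_k.tolist() == [0, 5, 22, 24]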