Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)
Using flash decoding
Conditional flash decoding. Fix max_q. Working KV cache. Working version with flash decoding. Make it work for Mistral.
parent cff472ba2b
commit 1b86d0f31d
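At a glance, the diff below adds a FLASH_DECODING environment toggle (globals.py hunk), bumps the KV-cache block size from 16 to 256 when it is on, and allocates each layer's key/value cache as plain (num_blocks, block_size, num_heads, head_size) tensors. A minimal sketch, not code from the commit, with made-up sizes just to make the shapes concrete:

import torch

# Pretend the FLASH_DECODING env toggle (see the globals.py hunk below) is on.
FLASH_DECODING = True
BLOCK_SIZE = 256 if FLASH_DECODING else 16

# Made-up sizes, only for illustration.
num_layers, num_blocks, num_heads, head_size = 2, 8, 4, 64

if FLASH_DECODING:
    # One (key, value) pair per layer, laid out as (block, position, head, dim).
    kv_cache = [
        (
            torch.empty(num_blocks, BLOCK_SIZE, num_heads, head_size),
            torch.empty(num_blocks, BLOCK_SIZE, num_heads, head_size),
        )
        for _ in range(num_layers)
    ]
    print(kv_cache[0][0].shape)  # torch.Size([8, 256, 4, 64])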
@@ -1,11 +1,11 @@
-flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
+flash_att_v2_commit_cuda := v2.5.8
 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
 
 
 flash-attention-v2-cuda:
 	# Clone flash attention
 	pip install -U packaging ninja --no-cache-dir
-	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
+	git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2
 
 build-flash-attention-v2-cuda: flash-attention-v2-cuda
 	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda)
@@ -3,8 +3,9 @@ import torch
 
 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import FLASH_DECODING
 
-BLOCK_SIZE: int = 16
+BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
 
@@ -30,6 +31,23 @@ class CacheManager:
         else:
             x = self.block_size // element_size
 
+        if FLASH_DECODING:
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, self.block_size, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, self.block_size, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]
+        else:
         self.kv_cache = [
             (
                 torch.empty(
@@ -28,6 +28,7 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -151,13 +152,49 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
-
         # output tensor
         attn_output = torch.empty_like(query)
 
+        if FLASH_DECODING:
+            # Prefill
+            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
+                :, 0
+            ]
+            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
+                :, 1
+            ]
+
+            if cu_seqlen_prefill is not None:
+                # flash attention
+                flash_attn.attention(
+                    query,
+                    # torch.select(kv, dim=1, index=0),
+                    # torch.select(kv, dim=1, index=1),
+                    kv_cache[0],
+                    kv_cache[1],
+                    attn_output,
+                    cu_seqlen_prefill,
+                    block_tables,
+                    max_s,
+                    self.softmax_scale,
+                )
+            # Decode
+            else:
+                paged_attention.attention(
+                    attn_output,
+                    query,
+                    kv_cache[0],
+                    kv_cache[1],
+                    self.kv_head_mapping,
+                    self.softmax_scale,
+                    block_tables,
+                    input_lengths,
+                    max_s,
+                )
+        else:
+            paged_attention.reshape_and_cache(
+                kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+            )
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
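In the FLASH_DECODING branch above, prefill writes keys and values straight into the cache through a flat view rather than going through paged_attention.reshape_and_cache: slot s corresponds to block s // block_size, position s % block_size. A tiny self-contained sketch of that indexing with made-up shapes (not the model code itself):

import torch

# Made-up sizes; the real code uses BLOCK_SIZE = 256 and the model's KV heads.
block_size, num_blocks, num_kv_heads, head_size = 4, 3, 2, 8

# Cache laid out as in the FLASH_DECODING branch: (block, position, head, dim).
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# Three new tokens, assigned arbitrary slots by the scheduler.
slots = torch.tensor([0, 5, 6])
new_keys = torch.randn(len(slots), num_kv_heads, head_size)

# Same pattern as kv_cache[0].view(-1, num_key_value_heads, head_size)[slots] = kv[:, 0]
key_cache.view(-1, num_kv_heads, head_size)[slots] = new_keys

# Slot 5 lands in block 1, position 1 (5 // 4 and 5 % 4).
assert torch.equal(key_cache[1, 1], new_keys[1])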
@@ -167,6 +204,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
+                None,
                 max_s,
                 self.softmax_scale,
             )
@@ -27,6 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -214,6 +215,45 @@ class MistralAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
+        attn_output = torch.empty_like(query)
+
+        if FLASH_DECODING:
+            # Prefill
+            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
+                :, 0
+            ]
+            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
+                :, 1
+            ]
+
+            if cu_seqlen_prefill is not None:
+                # flash attention
+                flash_attn.attention(
+                    query,
+                    # torch.select(kv, dim=1, index=0),
+                    # torch.select(kv, dim=1, index=1),
+                    kv_cache[0],
+                    kv_cache[1],
+                    attn_output,
+                    cu_seqlen_prefill,
+                    block_tables,
+                    max_s,
+                    self.softmax_scale,
+                )
+            # Decode
+            else:
+                paged_attention.attention(
+                    attn_output,
+                    query,
+                    kv_cache[0],
+                    kv_cache[1],
+                    self.kv_head_mapping,
+                    self.softmax_scale,
+                    block_tables,
+                    input_lengths,
+                    max_s,
+                )
+        else:
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
         else:
@@ -222,10 +262,6 @@
         paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
-
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
@@ -235,9 +271,9 @@ class MistralAttention(torch.nn.Module):
                 torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
+                None,
                 max_s,
                 self.softmax_scale,
-                window_size_left=self.max_past,
             )
         # Decode
         else:
@@ -1,9 +1,13 @@
 import torch
 import os
+from loguru import logger
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
+FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
+if FLASH_DECODING:
+    logger.info("Using FLASH_DECODING")
 if cuda_graphs is not None:
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
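Because the flag above is evaluated when text_generation_server.models.globals is imported, it has to be present in the environment before the server process imports that module. A hedged usage sketch (assumes the text_generation_server package is installed; the assertion is only for illustration):

import os

# Set the flag before the module is first imported, since it is read at import time.
os.environ["FLASH_DECODING"] = "1"

from text_generation_server.models.globals import FLASH_DECODING

assert FLASH_DECODING  # the import also logs "Using FLASH_DECODING" via loguru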
@@ -134,6 +134,7 @@ elif HAS_FLASH_ATTN_V2_CUDA:
         v,
         out,
         cu_seqlens,
+        block_tables,
         max_s,
         softmax_scale,
         window_size_left=-1,
@@ -149,7 +150,7 @@ elif HAS_FLASH_ATTN_V2_CUDA:
             cu_seqlens,
             cu_seqlens,
             None,
-            None,
+            block_tables,
             None,
             max_s,
             max_s,
@@ -85,6 +85,46 @@ def attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
+    if FLASH_DECODING:
+        cu_seqlen_q = torch.arange(
+            input_lengths.shape[0] + 1, device=query.device, dtype=torch.int32
+        )
+        cu_seqlen_k = torch.cat(
+            [
+                torch.zeros(
+                    (1,), device=input_lengths.device, dtype=input_lengths.dtype
+                ),
+                input_lengths.cumsum(dim=-1),
+            ]
+        ).to(dtype=torch.int32)
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda
+
+        flash_attn_2_cuda.varlen_fwd(
+            query,
+            key_cache,
+            value_cache,
+            out,
+            cu_seqlen_q,
+            cu_seqlen_k,
+            None,
+            block_tables,
+            None,
+            max_q,
+            max_k,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            -1,
+            0,
+            False,
+            None,
+        )
+    else:
+        from vllm._C import ops
+
     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
         ops.paged_attention_v1(
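The decode path built above hands flash_attn_2_cuda.varlen_fwd exactly one query token per sequence: cu_seqlen_q is simply 0..batch_size, cu_seqlen_k is the running sum of input_lengths, and max_q is fixed at 1. A small standalone illustration of those two tensors (the input_lengths values are invented):

import torch

# Per-sequence context lengths at this decode step (invented values).
input_lengths = torch.tensor([3, 7, 5])

# One query token per sequence, so cumulative query lengths are just 0..batch.
cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

# Cumulative key lengths: 0 followed by the running sum of the context lengths.
cu_seqlen_k = torch.cat(
    [
        torch.zeros((1,), dtype=input_lengths.dtype),
        input_lengths.cumsum(dim=-1),
    ]
).to(dtype=torch.int32)

max_q = 1  # every sequence contributes exactly one new query token

print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]
print(cu_seqlen_k.tolist())  # [0, 3, 10, 15]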