Add Flash decoding kernel ROCm (#2855)

* (vllm) updated vllm rocm kernels

* revert silu

* update partition size

* remove grouped_topk

* (nit) remove log

* add flash decoding
Author: Mohit Sharma, 2025-01-13 15:42:35 +05:30 (committed by GitHub)
Commit: 880ab9c2f3 (parent: 1660154ae6)
3 changed files with 53 additions and 11 deletions

--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2

@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
+flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd

 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel

--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py

@@ -5,6 +5,10 @@ from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
 from loguru import logger
 import vllm._custom_ops as ops
@@ -73,11 +77,44 @@ def paged_attention(
     # limitations under the License.
     #
+    if ATTENTION == "flashdecoding":
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda
+
+        if softcap is None:
+            softcap = 0.0
+        out = flash_attn_2_cuda.varlen_fwd(
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            None,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            None,  # pad_k
+            None,
+            block_tables,
+            None,
+            max_q,
+            max_k,
+            0.0,  # dropout
+            softmax_scale,
+            False,  # zero_tensors
+            True,  # causal
+            -1,  # Window_left
+            -1,  # Window right
+            softcap,
+            False,  # return softmax
+            None,  # generator
+        )
+        return out[0]
+
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")

     # value_cache => [num_blocks, num_heads, head_size, block_size]
-    block_size = kv_cache.value.shape[3]
+    # block_size = kv_cache.value.shape[3]
+    block_size = BLOCK_SIZE
     num_seqs, num_heads, head_size = query.shape

     num_kv_heads = kv_cache.key.shape[1]
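For orientation, below is a minimal sketch of the decode path this hunk introduces, with argument annotations added. The wrapper name flashdecoding_decode is illustrative and not part of the patch; it needs the ROCm flash-attention build pinned in the Makefile above. Each decode step issues a single query token per sequence (max_q = 1), while keys and values are read straight out of the paged KV cache through block_tables.

# Sketch only: mirrors the call added above; the argument order is exactly the one in the hunk.
def flashdecoding_decode(
    query, kv_cache, seqlen, block_tables, max_s, softmax_scale, softcap=None
):
    import flash_attn_2_cuda  # ROCm build, pinned in the Makefile above

    out = flash_attn_2_cuda.varlen_fwd(
        query,                 # [num_seqs, num_heads, head_size], one token per sequence
        kv_cache.key,          # paged key cache
        kv_cache.value,        # paged value cache
        None,                  # out slot (the attention() hunk below passes `out` here)
        seqlen.cu_seqlen_q,
        seqlen.cu_seqlen_k,
        None,                  # pad_k
        None,
        block_tables,          # maps each sequence to its KV-cache blocks
        None,
        1,                     # max_q: a single new token per sequence while decoding
        max_s,                 # max_k: longest total sequence length in the batch
        0.0,                   # dropout
        softmax_scale,
        False,                 # zero_tensors
        True,                  # causal
        -1,                    # window_left
        -1,                    # window_right
        0.0 if softcap is None else softcap,
        False,                 # return softmax
        None,                  # generator
    )
    return out[0]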
@@ -247,14 +284,15 @@ def attention(
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             query,
-            key,
-            value,
+            # flashdecoding: pass the KV caches, paged: pass the KV.
+            kv_cache.key if ATTENTION == "flashdecoding" else key,
+            kv_cache.value if ATTENTION == "flashdecoding" else value,
             out,
             seqlen.cu_seqlen_q,
-            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
             None,
             None,
-            None,
+            block_tables if ATTENTION == "flashdecoding" else None,
             None,
             seqlen.max_q,
             seqlen.max_k,
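The prefill path reuses the same kernel; the flashdecoding-specific difference is where K/V are read from, and the second cumulative-length argument is now seqlen.cu_seqlen_k rather than a repeat of cu_seqlen_q. A small self-contained toy illustration of the selection this hunk adds follows; the helper name select_kv_args and the SimpleNamespace stand-ins are illustrative, and the "paged" label follows the hunk's own comment.

from types import SimpleNamespace

def select_kv_args(backend, key, value, kv_cache, block_tables):
    # Mirrors the conditionals added to attention(): under flashdecoding the
    # kernel reads K/V through the paged cache and needs the block tables;
    # otherwise it keeps the contiguous key/value tensors and no table.
    if backend == "flashdecoding":
        return kv_cache.key, kv_cache.value, block_tables
    return key, value, None

# Toy stand-ins, just to show which objects end up in the kernel call.
kv_cache = SimpleNamespace(key="paged K cache", value="paged V cache")
print(select_kv_args("flashdecoding", "contiguous K", "contiguous V", kv_cache, [[0, 1]]))
print(select_kv_args("paged", "contiguous K", "contiguous V", kv_cache, [[0, 1]]))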

--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py

@@ -1663,7 +1663,7 @@ class FlashCausalLM(Model):
             for seqlen in tuning_sequences:
                 log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
-                self.tunableop_warmup(seqlen)
+                self.tunableop_warmup(seqlen, max_total_tokens)
             torch.cuda.tunable.write_file(tunableop_filepath)
             if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                 torch.cuda.tunable.tuning_enable(False)
@@ -1710,7 +1710,7 @@ class FlashCausalLM(Model):
         assert max_total_tokens is not None
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

-    def tunableop_warmup(self, seqlen: int):
+    def tunableop_warmup(self, seqlen: int, max_bt: int):
         input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
@@ -1724,11 +1724,15 @@
             [0, seqlen], device=self.device, dtype=torch.int32
         )
         max_s = seqlen
+
+        block_tables = torch.arange(
+            max_bt, dtype=torch.int32, device=self.device
+        ).repeat(seqlen)
+        block_tables = block_tables.reshape((seqlen, max_bt))
+
         seqlen = Seqlen(
             input_lengths=input_lengths,
             cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
-            max_q=1,
             max_k=seqlen,
         )
@@ -1738,7 +1742,7 @@
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=self.kv_cache,
-            block_tables=None,
+            block_tables=block_tables,
             seqlen=seqlen,
             slots=slots,
             max_s=max_s,
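The TunableOp warmup must now hand the model a block table, since with flashdecoding the attention kernel indexes the KV cache through it. Below is a small self-contained sketch of the dummy table built above and its shape; CPU tensors and tiny sizes are used purely for illustration, whereas the patch uses max_bt = max_total_tokens and self.device.

import torch

# Dummy block table as in tunableop_warmup above: every one of the `seqlen`
# warmup sequences gets the same block indices 0..max_bt-1, one row per sequence.
seqlen, max_bt = 4, 8  # illustrative sizes only
block_tables = torch.arange(max_bt, dtype=torch.int32).repeat(seqlen)
block_tables = block_tables.reshape((seqlen, max_bt))

assert block_tables.shape == (seqlen, max_bt)
assert block_tables[0].tolist() == list(range(max_bt))  # row i -> blocks 0..max_bt-1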