Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 13:52:07 +00:00)
Add Flash decoding kernel ROCm (#2855)
* (vllm) updated vllm rocm kernels
* revert silu
* update partition size
* remove grouped_topk
* (nit) remove log
* add flash decoding
This commit is contained in:
parent 1660154ae6
commit 880ab9c2f3
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
+flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd
 
 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel
@@ -5,6 +5,10 @@ from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
 from loguru import logger
 import vllm._custom_ops as ops
 
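The new imports bring ATTENTION and BLOCK_SIZE in from text_generation_server.models.globals so the ROCm attention layer can branch on the selected backend. As a rough, hypothetical sketch only (the real values and environment handling live in globals.py and are not part of this diff), the flag behaves roughly like this:

# Illustrative sketch, not the repository's globals.py.
# Assumption: the backend is chosen via an environment variable and the
# KV-cache block size follows from that choice.
import os

ATTENTION = os.environ.get("ATTENTION", "paged")  # e.g. "paged" or "flashdecoding"
BLOCK_SIZE = 256 if ATTENTION == "flashdecoding" else 16  # hypothetical mapping

print(f"attention backend: {ATTENTION}, kv-cache block size: {BLOCK_SIZE}")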
@@ -73,11 +77,44 @@ def paged_attention(
     # limitations under the License.
     #
 
+    if ATTENTION == "flashdecoding":
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda
+
+        if softcap is None:
+            softcap = 0.0
+        out = flash_attn_2_cuda.varlen_fwd(
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            None,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            None,  # pad_k
+            None,
+            block_tables,
+            None,
+            max_q,
+            max_k,
+            0.0,  # dropout
+            softmax_scale,
+            False,  # zero_tensors
+            True,  # causal
+            -1,  # Window_left
+            -1,  # Window right
+            softcap,
+            False,  # return softmax
+            None,  # generator
+        )
+        return out[0]
+
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")
 
     # value_cache => [num_blocks, num_heads, head_size, block_size]
-    block_size = kv_cache.value.shape[3]
+    # block_size = kv_cache.value.shape[3]
+    block_size = BLOCK_SIZE
     num_seqs, num_heads, head_size = query.shape
 
     num_kv_heads = kv_cache.key.shape[1]
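The flash-decoding branch above hands the paged KV cache straight to flash_attn_2_cuda.varlen_fwd, with max_q = 1 (one new query token per sequence at decode time) and block_tables telling the kernel where each sequence's KV blocks live. A minimal sketch of the metadata this call consumes, with toy sizes and without invoking the kernel itself (flash_attn_2_cuda requires a GPU build of flash-attention):

# Toy decode-batch metadata; nothing here calls flash_attn_2_cuda, it only
# builds the tensors the varlen_fwd decode path above consumes.
import torch

num_seqs = 3                                   # sequences being decoded together
block_size = 256                               # assumed flash-decoding block size
context_lens = torch.tensor([17, 130, 512])    # tokens already sitting in the KV cache

# One query token per sequence at decode time, hence max_q = 1.
cu_seqlen_q = torch.arange(num_seqs + 1, dtype=torch.int32)      # [0, 1, 2, 3]
cu_seqlen_k = torch.zeros(num_seqs + 1, dtype=torch.int32)
cu_seqlen_k[1:] = torch.cumsum(context_lens, dim=0)              # [0, 17, 147, 659]

# Each row of block_tables lists the KV-cache block ids owned by that sequence.
max_blocks = (int(context_lens.max()) + block_size - 1) // block_size
block_tables = torch.zeros((num_seqs, max_blocks), dtype=torch.int32)

print(cu_seqlen_q.tolist(), cu_seqlen_k.tolist(), tuple(block_tables.shape))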
@@ -247,14 +284,15 @@ def attention(
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             query,
-            key,
-            value,
+            # flashdecoding: pass the KV caches, paged: pass the KV.
+            kv_cache.key if ATTENTION == "flashdecoding" else key,
+            kv_cache.value if ATTENTION == "flashdecoding" else value,
             out,
             seqlen.cu_seqlen_q,
             seqlen.cu_seqlen_q,
             None,
             seqlen.cu_seqlen_k,
             None,
-            None,
+            block_tables if ATTENTION == "flashdecoding" else None,
             None,
             seqlen.max_q,
             seqlen.max_k,
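In the prefill path the same varlen_fwd entrypoint is reused: under flashdecoding the keys and values come out of the paged cache via block_tables, otherwise the contiguous key/value tensors are passed and the block table stays None. A small, self-contained illustration of that selection (the ToyKVCache type and helper name are made up for the example):

# Illustrative selection of prefill inputs; mirrors the ternaries in the hunk above.
from typing import NamedTuple, Optional
import torch

class ToyKVCache(NamedTuple):
    key: torch.Tensor      # paged layout, e.g. [num_blocks, block_size, num_kv_heads, head_dim]
    value: torch.Tensor

def pick_prefill_inputs(attention: str, kv_cache: ToyKVCache,
                        key: torch.Tensor, value: torch.Tensor,
                        block_tables: Optional[torch.Tensor]):
    """Return (k, v, block_table) the way the flashdecoding/paged branches do."""
    if attention == "flashdecoding":
        return kv_cache.key, kv_cache.value, block_tables
    return key, value, None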
@@ -1663,7 +1663,7 @@ class FlashCausalLM(Model):
 
             for seqlen in tuning_sequences:
                 log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
-                self.tunableop_warmup(seqlen)
+                self.tunableop_warmup(seqlen, max_total_tokens)
             torch.cuda.tunable.write_file(tunableop_filepath)
             if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                 torch.cuda.tunable.tuning_enable(False)
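The warmup loop now forwards max_total_tokens so tunableop_warmup can size a realistic block table. For orientation, a hedged outline of the surrounding TunableOp flow, using only PyTorch's torch.cuda.tunable API plus the environment variable visible in the hunk; the model object and tuning sequence lengths are placeholders, not code from this repository:

# Rough outline only; "model", "tuning_sequences" and "max_total_tokens" are
# placeholders, not objects from this repository.
import os
import torch

tunableop_filepath = "tunableop_results.csv"   # assumed output location
tuning_sequences = [1, 2, 4, 8, 16, 32]        # assumed warmup sequence lengths

torch.cuda.tunable.enable(True)         # turn TunableOp on
torch.cuda.tunable.tuning_enable(True)  # allow tuning while warming up

# for seqlen in tuning_sequences:
#     model.tunableop_warmup(seqlen, max_total_tokens)  # warm every GEMM shape once

torch.cuda.tunable.write_file(tunableop_filepath)       # persist tuned solutions
if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
    torch.cuda.tunable.tuning_enable(False)             # freeze results for serving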
@@ -1710,7 +1710,7 @@ class FlashCausalLM(Model):
         assert max_total_tokens is not None
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
-    def tunableop_warmup(self, seqlen: int):
+    def tunableop_warmup(self, seqlen: int, max_bt: int):
         input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
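The new max_bt argument sets the per-sequence block-table width used below; the warmup call above passes max_total_tokens for it. As an illustrative back-of-the-envelope check (numbers are made up), the blocks a sequence actually needs are ceil(max_total_tokens / BLOCK_SIZE), so a width of max_total_tokens is always sufficient:

# Illustrative arithmetic only: KV-cache blocks needed per sequence for a few
# assumed block sizes.
def blocks_needed(max_total_tokens: int, block_size: int) -> int:
    return -(-max_total_tokens // block_size)  # ceil division without floats

for block_size in (16, 256):
    print(block_size, blocks_needed(4096, block_size))  # 4096 tokens -> 256 or 16 blocks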
@@ -1724,11 +1724,15 @@ class FlashCausalLM(Model):
             [0, seqlen], device=self.device, dtype=torch.int32
         )
         max_s = seqlen
 
+        block_tables = torch.arange(
+            max_bt, dtype=torch.int32, device=self.device
+        ).repeat(seqlen)
+        block_tables = block_tables.reshape((seqlen, max_bt))
+
         seqlen = Seqlen(
             input_lengths=input_lengths,
             cache_lengths=cache_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
             max_q=1,
             max_k=seqlen,
         )
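The synthetic block table above is just torch.arange(max_bt) repeated once per warmup token and reshaped to (seqlen, max_bt). A tiny standalone check of that shape with toy numbers:

# Reproduces the block_tables construction above with toy sizes (CPU is fine here).
import torch

seqlen, max_bt = 4, 6
block_tables = torch.arange(max_bt, dtype=torch.int32).repeat(seqlen)
block_tables = block_tables.reshape((seqlen, max_bt))
print(block_tables.shape)  # torch.Size([4, 6]); every row is [0, 1, 2, 3, 4, 5]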
@@ -1738,7 +1742,7 @@ class FlashCausalLM(Model):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=self.kv_cache,
-            block_tables=None,
+            block_tables=block_tables,
             seqlen=seqlen,
             slots=slots,
             max_s=max_s,