mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Add Flash decoding kernel ROCm (#2855)
* (vllm) updated vllm rocm kernels
* revert silu
* update partition size
* remove grouped_topk
* (nit) remove log
* add flash decoding
This commit is contained in:
parent 1660154ae6
commit 880ab9c2f3
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
+flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd

 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel
@@ -5,6 +5,10 @@ from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
 from loguru import logger
 import vllm._custom_ops as ops

@@ -73,11 +77,44 @@ def paged_attention(
     # limitations under the License.
     #

+    if ATTENTION == "flashdecoding":
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda
+
+        if softcap is None:
+            softcap = 0.0
+        out = flash_attn_2_cuda.varlen_fwd(
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            None,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            None,  # pad_k
+            None,
+            block_tables,
+            None,
+            max_q,
+            max_k,
+            0.0,  # dropout
+            softmax_scale,
+            False,  # zero_tensors
+            True,  # causal
+            -1,  # Window_left
+            -1,  # Window right
+            softcap,
+            False,  # return softmax
+            None,  # generator
+        )
+        return out[0]
+
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")

     # value_cache => [num_blocks, num_heads, head_size, block_size]
-    block_size = kv_cache.value.shape[3]
+    # block_size = kv_cache.value.shape[3]
+    block_size = BLOCK_SIZE
     num_seqs, num_heads, head_size = query.shape

     num_kv_heads = kv_cache.key.shape[1]
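The flashdecoding branch added above hands flash_attn_2_cuda.varlen_fwd the whole paged KV cache plus the per-sequence block tables, with a single query token per sequence (max_q = 1, max_k = max_s). As a rough mental model of what that fused call computes during decode, here is a minimal CPU-only sketch of the gather-then-attend step; the cache layout, sizes, and names are illustrative assumptions for this sketch, not the exact TGI layout or kernel API.

import torch

torch.manual_seed(0)
# Toy paged cache, assumed laid out as [num_blocks, block_size, num_heads, head_size].
num_blocks, block_size, num_heads, head_size = 8, 4, 2, 16
key_cache = torch.randn(num_blocks, block_size, num_heads, head_size)
value_cache = torch.randn(num_blocks, block_size, num_heads, head_size)

# One sequence with 10 cached tokens scattered over physical blocks 5, 2 and 7.
seq_len = 10
block_table = torch.tensor([5, 2, 7])

# Gather the sequence's keys/values in logical order out of the paged cache.
keys = key_cache[block_table].reshape(-1, num_heads, head_size)[:seq_len]
values = value_cache[block_table].reshape(-1, num_heads, head_size)[:seq_len]

# A single new query token, matching max_q = 1 in the flashdecoding branch.
query = torch.randn(num_heads, head_size)
scale = head_size ** -0.5

# Naive per-head attention; the fused kernel produces the same result without
# materializing `keys`/`values`, walking the block table inside the kernel.
scores = torch.einsum("hd,khd->hk", query, keys) * scale
probs = torch.softmax(scores, dim=-1)
out = torch.einsum("hk,khd->hd", probs, values)
print(out.shape)  # torch.Size([2, 16])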
@@ -247,14 +284,15 @@ def attention(
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             query,
-            key,
-            value,
+            # flashdecoding: pass the KV caches, paged: pass the KV.
+            kv_cache.key if ATTENTION == "flashdecoding" else key,
+            kv_cache.value if ATTENTION == "flashdecoding" else value,
             out,
             seqlen.cu_seqlen_q,
-            seqlen.cu_seqlen_q,
-            None,
+            seqlen.cu_seqlen_k,
             None,
             None,
+            block_tables if ATTENTION == "flashdecoding" else None,
             None,
             seqlen.max_q,
             seqlen.max_k,
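With this change the same prefill varlen_fwd call serves both modes: under ATTENTION == "flashdecoding" it attends against the paged cache (kv_cache.key / kv_cache.value) addressed through block_tables and seqlen.cu_seqlen_k, otherwise it keeps the contiguous key/value tensors and passes no block table. A small sketch of that argument selection follows; select_prefill_kv is a hypothetical helper written only to illustrate the conditional arguments, not a function in the codebase.

from typing import Optional, Tuple
import torch

def select_prefill_kv(
    attention_mode: str,
    key: torch.Tensor,
    value: torch.Tensor,
    cache_key: torch.Tensor,
    cache_value: torch.Tensor,
    block_tables: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    if attention_mode == "flashdecoding":
        # Attend against the paged cache, addressed via the block tables.
        return cache_key, cache_value, block_tables
    # Classic paged attention: contiguous prefill K/V, no block table.
    return key, value, None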
@@ -1663,7 +1663,7 @@ class FlashCausalLM(Model):

             for seqlen in tuning_sequences:
                 log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
-                self.tunableop_warmup(seqlen)
+                self.tunableop_warmup(seqlen, max_total_tokens)
             torch.cuda.tunable.write_file(tunableop_filepath)
             if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                 torch.cuda.tunable.tuning_enable(False)
@@ -1710,7 +1710,7 @@ class FlashCausalLM(Model):
         assert max_total_tokens is not None
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

-    def tunableop_warmup(self, seqlen: int):
+    def tunableop_warmup(self, seqlen: int, max_bt: int):
         input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
@@ -1724,11 +1724,15 @@ class FlashCausalLM(Model):
             [0, seqlen], device=self.device, dtype=torch.int32
         )
         max_s = seqlen
+
+        block_tables = torch.arange(
+            max_bt, dtype=torch.int32, device=self.device
+        ).repeat(seqlen)
+        block_tables = block_tables.reshape((seqlen, max_bt))
+
         seqlen = Seqlen(
             input_lengths=input_lengths,
             cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
-            max_q=1,
             max_k=seqlen,
         )

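The block_tables built for the TunableOp warmup above are purely synthetic: the construction just tiles block ids 0..max_bt-1 once per warmup token. A quick runnable check of the resulting shape (same code as the diff, with the device argument dropped so it runs on CPU):

import torch

seqlen, max_bt = 3, 4
block_tables = torch.arange(max_bt, dtype=torch.int32).repeat(seqlen)
block_tables = block_tables.reshape((seqlen, max_bt))
print(block_tables)
# tensor([[0, 1, 2, 3],
#         [0, 1, 2, 3],
#         [0, 1, 2, 3]], dtype=torch.int32)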
@@ -1738,7 +1742,7 @@ class FlashCausalLM(Model):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=self.kv_cache,
-            block_tables=None,
+            block_tables=block_tables,
             seqlen=seqlen,
             slots=slots,
             max_s=max_s,