Speedup flashdecoding.

Nicolas Patry 2024-05-24 16:10:42 +00:00
parent ed96a76d67
commit 6bbc843097
2 changed files with 29 additions and 17 deletions
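The change below hoists the cu_seqlen_q / cu_seqlen_k bookkeeping out of the per-call attention wrapper and into FlashLlamaModel.forward, so the indexing tensors are built once per forward pass and passed down to every decoder layer instead of being recomputed inside each layer's attention call. As a rough standalone sketch of that bookkeeping (plain PyTorch; the helper name build_cu_seqlens is illustrative and not part of this commit):

import torch


def build_cu_seqlens(input_lengths: torch.Tensor):
    # Decode path: each request in the batch contributes exactly one new
    # query token, so the query prefix sums are simply 0..batch_size,
    # while the key/value prefix sums are the zero-prefixed cumulative
    # sum of the per-request lengths already held in the KV cache.
    cu_seqlen_q = torch.arange(
        input_lengths.shape[0] + 1,
        device=input_lengths.device,
        dtype=torch.int32,
    )
    cu_seqlen_k = torch.cat(
        [
            torch.zeros((1,), device=input_lengths.device, dtype=input_lengths.dtype),
            input_lengths.cumsum(dim=-1),
        ]
    ).to(dtype=torch.int32)
    return cu_seqlen_q, cu_seqlen_k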


@@ -129,7 +129,8 @@ class FlashLlamaAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        cu_seqlen_q,
+        cu_seqlen_k,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
qkv = self.query_key_value(hidden_states)
@@ -174,7 +175,8 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                cu_seqlen_q,
+                cu_seqlen_k,
                 max_s,
             )
@@ -275,7 +277,8 @@ class FlashLlamaLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        cu_seqlen_q,
+        cu_seqlen_k,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -289,7 +292,8 @@ class FlashLlamaLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            cu_seqlen_q,
+            cu_seqlen_k,
             max_s,
         )
@@ -356,6 +360,23 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
+        if cu_seqlen_prefill is None:
+            cu_seqlen_q = torch.arange(
+                input_lengths.shape[0] + 1,
+                device=inputs_embeds.device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.cat(
+                [
+                    torch.zeros(
+                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
+                    ),
+                    input_lengths.cumsum(dim=-1),
+                ]
+            ).to(dtype=torch.int32)
+        else:
+            cu_seqlen_q = None
+            cu_seqlen_k = input_lengths
 
         residual = None
         for i, layer in enumerate(self.layers):
@@ -368,7 +389,8 @@ class FlashLlamaModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                cu_seqlen_q,
+                cu_seqlen_k,
                 max_s,
             )
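For concreteness, on the decode path (cu_seqlen_prefill is None) the new branch just builds prefix-sum tensors; with a hypothetical batch of three requests whose cached lengths are 3, 5 and 2:

import torch

input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)
cu_seqlen_k = torch.cat(
    [torch.zeros((1,), dtype=torch.int32), input_lengths.cumsum(dim=-1)]
).to(dtype=torch.int32)

print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]
print(cu_seqlen_k.tolist())  # [0, 3, 8, 10]

On the prefill path the else branch leaves cu_seqlen_q as None and reuses input_lengths directly as cu_seqlen_k.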


@@ -46,7 +46,8 @@ def attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: torch.Tensor,
+    cu_seqlen_q: torch.Tensor,
+    cu_seqlen_k: torch.Tensor,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -92,17 +93,6 @@ def attention(
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
     if FLASH_DECODING:
-        cu_seqlen_q = torch.arange(
-            input_lengths.shape[0] + 1, device=query.device, dtype=torch.int32
-        )
-        cu_seqlen_k = torch.cat(
-            [
-                torch.zeros(
-                    (1,), device=input_lengths.device, dtype=input_lengths.dtype
-                ),
-                input_lengths.cumsum(dim=-1),
-            ]
-        ).to(dtype=torch.int32)
         max_q = 1
         max_k = max_s
         import flash_attn_2_cuda
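With the prefix sums now arriving precomputed from the model, the remaining per-call work on the FLASH_DECODING path is limited to picking max_q = 1 (each request decodes a single new query token) and max_k = max_s before handing off to flash_attn_2_cuda; the arange/cumsum pair is built once per forward pass rather than once per decoder layer, which is presumably where the speedup in the commit title comes from.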