Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00
Speedup flashdecoding.
This commit is contained in: parent ed96a76d67, commit 6bbc843097
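In short, the commit threads cu_seqlen_q and cu_seqlen_k through FlashLlamaAttention, FlashLlamaLayer, and the attention() helper in place of input_lengths, and builds them once in FlashLlamaModel.forward instead of recomputing them inside every attention call during decode. The sketch below is a standalone illustration of what those two tensors contain; the build_cu_seqlens helper is hypothetical and not part of the repository, it only mirrors the computation visible in the diff.

import torch

def build_cu_seqlens(input_lengths: torch.Tensor):
    # During decode there is one query token per sequence, so the query
    # offsets are simply 0, 1, 2, ..., batch_size.
    cu_seqlen_q = torch.arange(
        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
    )
    # Key/value offsets are the running sum of the per-sequence lengths,
    # prefixed with 0: [0, len_0, len_0 + len_1, ...].
    cu_seqlen_k = torch.cat(
        [
            torch.zeros((1,), device=input_lengths.device, dtype=input_lengths.dtype),
            input_lengths.cumsum(dim=-1),
        ]
    ).to(dtype=torch.int32)
    return cu_seqlen_q, cu_seqlen_k

# Example: a decode batch of 3 sequences with cached lengths 5, 2 and 7.
lengths = torch.tensor([5, 2, 7], dtype=torch.int32)
q, k = build_cu_seqlens(lengths)
# q -> tensor([0, 1, 2, 3]), k -> tensor([0, 5, 7, 14])

Hoisting this out of the per-layer attention call means the cumulative-length tensors are allocated once per forward pass rather than once per layer, which is where the decode speedup comes from.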
@@ -129,7 +129,8 @@ class FlashLlamaAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        cu_seqlen_q,
+        cu_seqlen_k,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -174,7 +175,8 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                cu_seqlen_q,
+                cu_seqlen_k,
                 max_s,
             )

@@ -275,7 +277,8 @@ class FlashLlamaLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        cu_seqlen_q,
+        cu_seqlen_k,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -289,7 +292,8 @@ class FlashLlamaLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            cu_seqlen_q,
+            cu_seqlen_k,
             max_s,
         )

@@ -356,6 +360,23 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
+        if cu_seqlen_prefill is None:
+            cu_seqlen_q = torch.arange(
+                input_lengths.shape[0] + 1,
+                device=inputs_embeds.device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.cat(
+                [
+                    torch.zeros(
+                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
+                    ),
+                    input_lengths.cumsum(dim=-1),
+                ]
+            ).to(dtype=torch.int32)
+        else:
+            cu_seqlen_q = None
+            cu_seqlen_k = input_lengths

         residual = None
         for i, layer in enumerate(self.layers):
@@ -368,7 +389,8 @@ class FlashLlamaModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                cu_seqlen_q,
+                cu_seqlen_k,
                 max_s,
             )

@@ -46,7 +46,8 @@ def attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: torch.Tensor,
+    cu_seqlen_q: torch.Tensor,
+    cu_seqlen_k: torch.Tensor,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -92,17 +93,6 @@ def attention(
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
     if FLASH_DECODING:
-        cu_seqlen_q = torch.arange(
-            input_lengths.shape[0] + 1, device=query.device, dtype=torch.int32
-        )
-        cu_seqlen_k = torch.cat(
-            [
-                torch.zeros(
-                    (1,), device=input_lengths.device, dtype=input_lengths.dtype
-                ),
-                input_lengths.cumsum(dim=-1),
-            ]
-        ).to(dtype=torch.int32)
         max_q = 1
         max_k = max_s
         import flash_attn_2_cuda