From 6bbc843097d85b3a8ef705ea271587665736d39a Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 24 May 2024 16:10:42 +0000
Subject: [PATCH] Speedup flashdecoding.

---
 .../custom_modeling/flash_llama_modeling.py   | 32 ++++++++++++++++---
 .../utils/paged_attention.py                  | 14 ++------
 2 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index fa3a78f8..bc71d598 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -129,7 +129,8 @@ class FlashLlamaAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        cu_seqlen_q,
+        cu_seqlen_k,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -174,7 +175,8 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                cu_seqlen_q,
+                cu_seqlen_k,
                 max_s,
             )

@@ -275,7 +277,8 @@ class FlashLlamaLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        cu_seqlen_q,
+        cu_seqlen_k,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -289,7 +292,8 @@ class FlashLlamaLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            cu_seqlen_q,
+            cu_seqlen_k,
             max_s,
         )

@@ -356,6 +360,23 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
+        if cu_seqlen_prefill is None:
+            cu_seqlen_q = torch.arange(
+                input_lengths.shape[0] + 1,
+                device=inputs_embeds.device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.cat(
+                [
+                    torch.zeros(
+                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
+                    ),
+                    input_lengths.cumsum(dim=-1),
+                ]
+            ).to(dtype=torch.int32)
+        else:
+            cu_seqlen_q = None
+            cu_seqlen_k = input_lengths

         residual = None
         for i, layer in enumerate(self.layers):
@@ -368,7 +389,8 @@ class FlashLlamaModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                cu_seqlen_q,
+                cu_seqlen_k,
                 max_s,
             )

diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index f8af5dc4..e9dd1249 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -46,7 +46,8 @@ def attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: torch.Tensor,
+    cu_seqlen_q: torch.Tensor,
+    cu_seqlen_k: torch.Tensor,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -92,17 +93,6 @@ def attention(
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
     if FLASH_DECODING:
-        cu_seqlen_q = torch.arange(
-            input_lengths.shape[0] + 1, device=query.device, dtype=torch.int32
-        )
-        cu_seqlen_k = torch.cat(
-            [
-                torch.zeros(
-                    (1,), device=input_lengths.device, dtype=input_lengths.dtype
-                ),
-                input_lengths.cumsum(dim=-1),
-            ]
-        ).to(dtype=torch.int32)
         max_q = 1
         max_k = max_s
         import flash_attn_2_cuda
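
For reference, the tensors this patch hoists out of paged_attention.attention and into FlashLlamaModel.forward are the cumulative sequence lengths expected by the flash-decoding kernel; computing them once per forward pass instead of once per decoder layer is where the speedup comes from. Below is a minimal standalone sketch of that computation, following the same construction as the diff; the input_lengths values are hypothetical example data:

    import torch

    # Hypothetical per-request KV lengths for a decode batch of 3 requests.
    input_lengths = torch.tensor([5, 9, 3], dtype=torch.int32)

    # Decoding emits exactly one query token per request, so the cumulative
    # query lengths are simply 0, 1, ..., batch_size.
    cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

    # Cumulative key lengths: a leading zero followed by the running sum of
    # the per-request KV lengths.
    cu_seqlen_k = torch.cat(
        [
            torch.zeros((1,), dtype=input_lengths.dtype),
            input_lengths.cumsum(dim=-1),
        ]
    ).to(dtype=torch.int32)

    print(cu_seqlen_q)  # -> [0, 1, 2, 3] (int32)
    print(cu_seqlen_k)  # -> [0, 5, 14, 17] (int32)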