From 7822bfd68fb9e1291591d22eb968591655e79860 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Thu, 17 Oct 2024 07:56:51 +0000
Subject: [PATCH] Fixup flashinfer support

---
 .../text_generation_server/layers/attention/cuda.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 514ad7ec..5846bfe5 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -252,7 +252,9 @@ def attention(
             window_left=window_size_left,
         )
 
-    elif ATTENTION == "flashdecoding":
+    # If we are using flashdecoding or paged, we always use flash-attn for
+    # the prefill. We have to branch on whether we use flash-attn v1 or v2.
+    elif V2:
         out = torch.empty_like(query)
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -262,14 +264,15 @@ def attention(
 
         return flash_attn_2_cuda.varlen_fwd(
             query,
-            kv_cache.key,
-            kv_cache.value,
+            # flashdecoding: pass the KV caches, paged: pass the KV.
+            kv_cache.key if ATTENTION == "flashdecoding" else key,
+            kv_cache.value if ATTENTION == "flashdecoding" else value,
             out,
             seqlen.cu_seqlen_q,
             seqlen.cu_seqlen_k,
             None,
             None,
-            block_tables,
+            block_tables if ATTENTION == "flashdecoding" else None,
             None,
             seqlen.max_q,
             seqlen.max_k,
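
Illustrative sketch, not part of the patch above: the hunks make the flash-attn v2 prefill read either from the paged KV cache (flashdecoding, which also needs the block table) or from the freshly computed key/value tensors (paged, no block table). The helper below, with the hypothetical name select_prefill_kv, only mirrors that argument selection; ATTENTION, kv_cache, block_tables and the tensor names are assumed to behave as they do in the patched file.

def select_prefill_kv(attention: str, kv_cache, key, value, block_tables):
    """Choose the K/V tensors and block table for the flash-attn v2 prefill.

    flashdecoding prefills straight from the paged KV cache, so it passes
    the cache tensors plus the block table; paged attention prefills from
    the current key/value tensors and passes no block table.
    """
    if attention == "flashdecoding":
        return kv_cache.key, kv_cache.value, block_tables
    return key, value, None

# Example (hypothetical call): paged attention falls back to the plain tensors.
# k, v, bt = select_prefill_kv("paged", kv_cache, key, value, block_tables)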