diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 514ad7ec..5846bfe5 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -252,7 +252,9 @@ def attention(
             window_left=window_size_left,
         )
 
-    elif ATTENTION == "flashdecoding":
+    # If we are using flashdecoding or paged, we always use flash-attn for
+    # the prefill. We have to branch on whether we use flash-attn v1 or v2.
+    elif V2:
         out = torch.empty_like(query)
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -262,14 +264,15 @@ def attention(
 
         return flash_attn_2_cuda.varlen_fwd(
             query,
-            kv_cache.key,
-            kv_cache.value,
+            # flashdecoding: pass the KV caches, paged: pass the KV.
+            kv_cache.key if ATTENTION == "flashdecoding" else key,
+            kv_cache.value if ATTENTION == "flashdecoding" else value,
             out,
             seqlen.cu_seqlen_q,
             seqlen.cu_seqlen_k,
             None,
             None,
-            block_tables,
+            block_tables if ATTENTION == "flashdecoding" else None,
             None,
             seqlen.max_q,
             seqlen.max_k,
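
Note on the dispatch above: a minimal, illustrative sketch of the argument selection this patch introduces. The helper name prefill_kv_args is hypothetical and not part of the patch; ATTENTION, kv_cache, and block_tables are the names used in cuda.py.

def prefill_kv_args(ATTENTION, key, value, kv_cache, block_tables):
    # flashdecoding: the prefill attends over the paged KV cache, so the
    # cache tensors are passed together with the block tables.
    if ATTENTION == "flashdecoding":
        return kv_cache.key, kv_cache.value, block_tables
    # paged: the prefill attends over the freshly projected key/value for
    # this batch, with no block tables.
    return key, value, None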