Fixup flashinfer support

Daniël de Kok 2024-10-17 07:56:51 +00:00
parent 07128cc178
commit 7822bfd68f


@@ -252,7 +252,9 @@ def attention(
             window_left=window_size_left,
         )
-    elif ATTENTION == "flashdecoding":
+    # If we are using flashdecoding or paged, we always use flash-attn for
+    # the prefill. We have to branch on whether we use flash-attn v1 or v2.
+    elif V2:
         out = torch.empty_like(query)
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -262,14 +264,15 @@ def attention(
         return flash_attn_2_cuda.varlen_fwd(
             query,
-            kv_cache.key,
-            kv_cache.value,
+            # flashdecoding: pass the KV caches, paged: pass the KV.
+            kv_cache.key if ATTENTION == "flashdecoding" else key,
+            kv_cache.value if ATTENTION == "flashdecoding" else value,
             out,
             seqlen.cu_seqlen_q,
             seqlen.cu_seqlen_k,
             None,
             None,
-            block_tables,
+            block_tables if ATTENTION == "flashdecoding" else None,
             None,
             seqlen.max_q,
             seqlen.max_k,
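
The comment added in the first hunk describes a dispatch: flashinfer runs the prefill with its own paged kernel, while flashdecoding and paged always use flash-attn for the prefill and only branch on whether v1 or v2 is available. Below is a minimal, runnable sketch of that dispatch; `ATTENTION` and `V2` stand in for the module-level settings in text-generation-inference, and the returned strings are only labels, not real kernel calls.

```python
# Sketch only: ATTENTION and V2 are placeholders for the real settings.
ATTENTION = "flashdecoding"  # one of: "flashinfer", "flashdecoding", "paged"
V2 = True                    # whether flash-attn v2 is available


def pick_prefill_backend(window_size_left: int = -1) -> str:
    if ATTENTION == "flashinfer":
        # flashinfer handles the paged prefill with its own kernel.
        return "flashinfer.prefill_with_paged_kv"
    # flashdecoding and paged both use flash-attn for the prefill;
    # the remaining branch is only flash-attn v1 vs. v2.
    elif V2:
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
        return "flash_attn_2_cuda.varlen_fwd"
    else:
        return "flash_attn (v1)"


print(pick_prefill_backend())  # -> flash_attn_2_cuda.varlen_fwd
```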
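The ternaries in the second hunk decide where flash-attn v2 reads K/V from: flashdecoding pulls from the paged KV cache and therefore also passes block_tables, while paged passes the freshly projected key/value and no block tables. A small self-contained sketch of that selection follows, with a hypothetical KVCache dataclass standing in for the real cache object.

```python
from dataclasses import dataclass


@dataclass
class KVCache:
    # Hypothetical stand-in for the real paged KV cache object.
    key: str = "cache_k"
    value: str = "cache_v"


def select_kv_args(attention_mode, key, value, kv_cache, block_tables):
    """Pick the K/V tensors and block tables handed to varlen_fwd."""
    if attention_mode == "flashdecoding":
        # flashdecoding reads K/V from the paged cache and needs block_tables.
        return kv_cache.key, kv_cache.value, block_tables
    # paged passes the freshly projected key/value and no block tables.
    return key, value, None


print(select_kv_args("paged", "k", "v", KVCache(), "tables"))
# ('k', 'v', None)
print(select_kv_args("flashdecoding", "k", "v", KVCache(), "tables"))
# ('cache_k', 'cache_v', 'tables')
```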