Mirror of https://github.com/huggingface/text-generation-inference.git
Fixup flashinfer support

commit 7822bfd68f
parent 07128cc178
@@ -252,7 +252,9 @@ def attention(
             window_left=window_size_left,
         )
 
-    elif ATTENTION == "flashdecoding":
+    # If we are using flashdecoding or paged, we always use flash-attn for
+    # the prefill. We have to branch on whether we use flash-attn v1 or v2.
+    elif V2:
         out = torch.empty_like(query)
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
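For context: after this hunk the flashinfer branch keeps serving its own prefill, while both the flashdecoding and paged layouts fall through to the flash-attn path, which now only branches on whether flash-attn v2 is in use (V2). Below is a minimal sketch of that dispatch; ATTENTION and V2 are stand-in globals read from hypothetical environment variables, not TGI's real settings module.

import os

# Hypothetical stand-ins for TGI's configuration; the real values come from
# text_generation_server's attention settings, not environment variables.
ATTENTION = os.environ.get("ATTENTION", "flashdecoding")  # "flashinfer" | "flashdecoding" | "paged"
V2 = os.environ.get("FLASH_ATTN_V2", "1") == "1"


def select_prefill_backend() -> str:
    """Mirror the branch structure after this commit (a sketch, not TGI's API)."""
    if ATTENTION == "flashinfer":
        # flashinfer serves the prefill directly from its paged KV state.
        return "flashinfer.prefill_with_paged_kv"
    elif V2:
        # flashdecoding and paged both use flash-attn v2 for the prefill.
        return "flash_attn_2_cuda.varlen_fwd"
    else:
        # Fallback to flash-attn v1.
        return "flash_attn_cuda.fwd"


if __name__ == "__main__":
    print(select_prefill_backend())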
@@ -262,14 +264,15 @@ def attention(
 
         return flash_attn_2_cuda.varlen_fwd(
             query,
-            kv_cache.key,
-            kv_cache.value,
+            # flashdecoding: pass the KV caches, paged: pass the KV.
+            kv_cache.key if ATTENTION == "flashdecoding" else key,
+            kv_cache.value if ATTENTION == "flashdecoding" else value,
             out,
             seqlen.cu_seqlen_q,
             seqlen.cu_seqlen_k,
             None,
             None,
-            block_tables,
+            block_tables if ATTENTION == "flashdecoding" else None,
             None,
             seqlen.max_q,
             seqlen.max_k,
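The second hunk applies the same split inside the flash-attn v2 prefill call: under flashdecoding the kernel reads keys and values from the paged KV cache and needs the block tables, whereas under paged it receives the freshly computed key/value tensors and no block tables. A small sketch of that argument selection follows; the KVCache dataclass and prefill_kv_args helper are hypothetical stand-ins for TGI's real types, kept separate from the actual CUDA call.

from dataclasses import dataclass

import torch


@dataclass
class KVCache:
    # Hypothetical container; TGI's real KVCache holds paged key/value blocks.
    key: torch.Tensor
    value: torch.Tensor


def prefill_kv_args(
    attention_mode: str,
    kv_cache: KVCache,
    key: torch.Tensor,
    value: torch.Tensor,
    block_tables: torch.Tensor,
):
    """Pick the key/value/block-table arguments the way the patched call does."""
    if attention_mode == "flashdecoding":
        # flashdecoding: pass the KV caches plus the block tables.
        return kv_cache.key, kv_cache.value, block_tables
    # paged: pass the per-request KV tensors; no block tables for the prefill.
    return key, value, None


if __name__ == "__main__":
    kv_cache = KVCache(key=torch.zeros(4, 16, 8, 64), value=torch.zeros(4, 16, 8, 64))
    key = torch.zeros(32, 8, 64)
    value = torch.zeros(32, 8, 64)
    block_tables = torch.zeros(1, 4, dtype=torch.int32)

    k, v, bt = prefill_kv_args("paged", kv_cache, key, value, block_tables)
    assert k is key and v is value and bt is None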