(fix) flashinfer: pass window_left instead of window_size_left

Mohit Sharma 2025-03-13 21:32:38 +00:00
parent ff82f0f84c
commit 69e0a87dd5


@@ -80,7 +80,7 @@ def paged_attention(
             sm_scale=softmax_scale,
             k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
             v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
-            window_size_left=window_size_left,
+            window_left=window_size_left,
         )
     elif ATTENTION == "flashdecoding":
         max_q = 1
@@ -257,7 +257,7 @@ def attention(
             sm_scale=softmax_scale,
             k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
             v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
-            window_size_left=window_size_left,
+            window_left=window_size_left,
         )
     # If we are using flashdecoding or paged, we always use flash-attn for
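
For context: flash-attn exposes the sliding-window size as window_size_left, while flashinfer's Python API names the same parameter window_left, which is what this commit renames at both call sites. Below is a minimal sketch against flashinfer's single_prefill_with_kv_cache entry point, not the TGI call itself; the shapes, dtype, and window size are illustrative assumptions, and running it requires flashinfer plus a CUDA device.

    import torch
    import flashinfer

    head_dim = 128
    # Illustrative shapes: 32 tokens, 8 heads, NHD layout.
    q = torch.randn(32, 8, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(32, 8, head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn(32, 8, head_dim, dtype=torch.float16, device="cuda")

    out = flashinfer.single_prefill_with_kv_cache(
        q, k, v,
        causal=True,
        sm_scale=head_dim**-0.5,
        # flashinfer's keyword is window_left (-1 disables the window);
        # passing flash-attn's window_size_left here raises a TypeError.
        window_left=256,
    )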