Update window size for ROCm flash decoding

Mohit Sharma 2025-03-14 07:50:11 +00:00
parent b30cdabf68
commit 170a12f331


@@ -83,6 +83,8 @@ def paged_attention(
         max_k = max_s
         import flash_attn_2_cuda
 
+        window_size_right = -1 if window_size_left == -1 else 0
+
         if softcap is None:
             softcap = 0.0
         out = flash_attn_2_cuda.varlen_fwd(
@@ -102,8 +104,8 @@ def paged_attention(
             softmax_scale,
             False,  # zero_tensors
             True,  # causal
-            -1,  # Window_left
-            -1,  # Window right
+            window_size_left,  # Window_left
+            window_size_right,  # Window right
             softcap,
             False,  # return softmax
             None,  # generator
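
The change replaces the hard-coded -1 window arguments with the window_size_left value passed into paged_attention, deriving window_size_right from the causal convention: right context is clamped to 0 whenever a left window is in effect, and only falls back to -1 (unbounded) when windowing is disabled entirely. As a minimal illustrative sketch of the sliding-window semantics these two arguments encode in FlashAttention-style kernels (the mask builder below is hypothetical, not code from this repository):

import numpy as np

def sliding_window_mask(seq_len, window_size_left, window_size_right):
    """mask[i, j] is True when query position i may attend to key position j.

    A window size of -1 means "unbounded" on that side, matching the
    FlashAttention convention assumed here.
    """
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    left_ok = (window_size_left == -1) | (j >= i - window_size_left)
    right_ok = (window_size_right == -1) | (j <= i + window_size_right)
    return left_ok & right_ok

# The patched logic: keep window_size_right at 0 for causal decoding
# unless windowing is disabled outright (window_size_left == -1).
window_size_left = 4
window_size_right = -1 if window_size_left == -1 else 0

mask = sliding_window_mask(8, window_size_left, window_size_right)
print(mask.astype(int))  # row i attends only to keys in [i - 4, i]

With window_size_left = 4 and the derived window_size_right = 0, the printed mask is banded and lower-triangular, which is the causal sliding-window behavior the two restored arguments now hand through to flash_attn_2_cuda.varlen_fwd instead of the previous unbounded -1/-1 defaults.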