diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 682aade26..518e55eee 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -83,6 +83,8 @@ def paged_attention(
         max_k = max_s
         import flash_attn_2_cuda
 
+        window_size_right = -1 if window_size_left == -1 else 0
+
         if softcap is None:
             softcap = 0.0
         out = flash_attn_2_cuda.varlen_fwd(
@@ -102,8 +104,8 @@ def paged_attention(
             softmax_scale,
             False,  # zero_tensors
             True,  # causal
-            -1,  # Window_left
-            -1,  # Window right
+            window_size_left,  # Window_left
+            window_size_right,  # Window right
             softcap,
             False,  # return softmax
             None,  # generator
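
The key change is deriving `window_size_right` from `window_size_left` (-1 keeps the window unbounded, any finite left window forces a right window of 0) and forwarding both values to `flash_attn_2_cuda.varlen_fwd` instead of the hard-coded `-1, -1`. Below is a minimal, self-contained sketch of what those two values mean for which keys a query position may attend to; the `sliding_window_mask` helper and the brute-force boolean mask are illustrative assumptions for this note, not the flash-attention kernel's actual implementation (which additionally applies the separate `causal` flag passed above).

```python
import torch


def sliding_window_mask(
    seq_len: int, window_size_left: int, window_size_right: int
) -> torch.Tensor:
    """Boolean mask where mask[q, k] is True if query q may attend to key k.

    Illustrative only: a value of -1 means "unbounded" on that side, mirroring
    the convention used by the flash-attention window arguments in the diff.
    """
    q = torch.arange(seq_len).unsqueeze(1)  # query positions as a column
    k = torch.arange(seq_len).unsqueeze(0)  # key positions as a row

    if window_size_left == -1:
        left_ok = torch.ones(seq_len, seq_len, dtype=torch.bool)
    else:
        left_ok = k >= (q - window_size_left)

    if window_size_right == -1:
        right_ok = torch.ones(seq_len, seq_len, dtype=torch.bool)
    else:
        right_ok = k <= (q + window_size_right)

    return left_ok & right_ok


if __name__ == "__main__":
    # (-1, -1): unbounded window on both sides, matching both the old
    # hard-coded behaviour and the new behaviour when window_size_left == -1.
    print(sliding_window_mask(5, -1, -1).int())

    # (window_size_left, 0): what the patch now passes when a finite sliding
    # window is configured; each query keeps only the last window_size_left
    # keys plus its own position.
    print(sliding_window_mask(5, 2, 0).int())
```

Note that when `window_size_left == -1` the arguments passed to the kernel are identical to the old `-1, -1`, so models without a sliding window are unaffected by this change.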