0.0 is the null value in the C++ API.

Nicolas Patry 2024-07-22 15:59:09 +00:00
parent c4b78bd214
commit 5266f15ae1


@@ -84,6 +84,8 @@ def paged_attention(
     # by the current path
     # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
     # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
+    if softcap is None:
+        softcap = 0.0
     out2 = flash_attn_2_cuda.varlen_fwd(
         query,
         key_cache,
@@ -211,7 +213,7 @@ if V2:
         softmax_scale,
         window_size_left=-1,
         causal=True,
-        softcap=None,
+        softcap=0.0,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
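For context, the change amounts to mapping Python's `None` onto the `0.0` sentinel before the value crosses into the C++ extension, since `flash_attn_2_cuda.varlen_fwd` has no notion of `None` and treats `0.0` as "softcapping disabled". A minimal sketch of that normalization follows; the helper name `normalize_softcap` is hypothetical and not part of the commit.

```python
# Minimal sketch (assumption, not the commit's code): coerce an optional
# softcap to the 0.0 sentinel the C++ kernel expects.
from typing import Optional


def normalize_softcap(softcap: Optional[float]) -> float:
    """The C++ API cannot receive None; 0.0 means 'no softcapping'."""
    return 0.0 if softcap is None else softcap


# Usage: None is folded to the sentinel, explicit values pass through.
assert normalize_softcap(None) == 0.0
assert normalize_softcap(30.0) == 30.0
```

Doing the coercion on the Python side (both as the keyword default and in the `None` check before the kernel call) keeps the C++ signature plain `float` and avoids branching inside the extension.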