fix: attempt forward on flash attn2 to check hardware support

drbh 2024-07-30 17:20:40 +00:00
parent 53aec27328
commit 4b1005c7e1

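For reference, the same kind of probe can be reproduced standalone through the public flash_attn Python wrapper. This is an illustrative sketch rather than part of the commit; the shapes, dtype, and head dim are assumptions about what the kernels accept (fp16 CUDA tensors, head dim 64):

    # Illustrative standalone check, not part of this diff.
    import torch

    try:
        from flash_attn import flash_attn_func

        # (batch, seqlen, num_heads, head_dim), fp16 on the current GPU
        q = torch.zeros(1, 1, 1, 64, dtype=torch.float16, device="cuda")
        k = torch.zeros_like(q)
        v = torch.zeros_like(q)
        flash_attn_func(q, k, v)
        print("Flash Attention V2 kernels ran on this GPU")
    except (ImportError, RuntimeError) as e:
        print(f"Flash Attention V2 is not usable here: {e}")

If the dummy forward pass raises, the change below converts that failure into an ImportError so the surrounding import logic (the except ImportError branch at the end of the hunk) can fall back to the older kernels.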

@@ -173,6 +173,41 @@ def paged_attention(
try:
    import flash_attn_2_cuda

    # Run a tiny forward pass with dummy inputs to verify that the
    # flash-attn v2 kernels can actually execute on this GPU.
    batch_size = 1
    num_heads = 1
    head_dim = 64  # use a head dim the raw kernel accepts (multiple of 8, <= 256)
    seqlen = 1
    try:
        q = torch.zeros(
            batch_size * seqlen,
            num_heads,
            head_dim,
            dtype=torch.float16,
            device="cuda",
        )
        k = torch.zeros_like(q)
        v = torch.zeros_like(q)
        cu_seqlens = torch.arange(
            0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device="cuda"
        )
        # Argument order follows the flash-attn v2.6 varlen_fwd binding.
        flash_attn_2_cuda.varlen_fwd(
            q,  # q: (total_tokens, num_heads, head_dim)
            k,  # k
            v,  # v
            None,  # out (optional)
            cu_seqlens,  # cu_seqlens_q
            cu_seqlens,  # cu_seqlens_k
            None,  # seqused_k (optional)
            None,  # leftpad_k (optional)
            None,  # block_table (optional)
            None,  # alibi_slopes (optional)
            seqlen,  # max_seqlen_q
            seqlen,  # max_seqlen_k
            0.0,  # dropout_p
            1.0,  # softmax_scale
            False,  # zero_tensors
            False,  # causal
            -1,  # window_size_left
            -1,  # window_size_right
            0.0,  # softcap
            False,  # return_softmax
            None,  # generator (optional)
        )
    except RuntimeError as e:
        raise ImportError(
            "Flash Attention V2 is not supported on this machine. " f"Error: {e}"
        ) from e
    V2 = True
except ImportError:
    try: