Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
fix: attempt forward on flash attn2 to check hardware support
This commit is contained in:
parent 53aec27328, commit 4b1005c7e1
@@ -173,6 +173,41 @@ def paged_attention(

try:
    import flash_attn_2_cuda

    # try forwarding to see if it works with all dummy inputs
    batch_size = 1
    num_heads = 1
    head_dim = 1
    seqlen = 1

    try:
        flash_attn_2_cuda.varlen_fwd(
            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # q
            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # k
            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # v
            None,  # out (optional)
            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_q
            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_k
            None,  # alibi_slopes (optional)
            None,  # q_padded (optional)
            None,  # k_padded (optional)
            None,  # v_padded (optional)
            seqlen,  # max_seqlen_q
            seqlen,  # max_seqlen_k
            1.0,  # softmax_scale
            0.0,  # softmax_lse (default value)
            False,  # is_causal
            True,  # return_softmax
            -1,  # window_size_left
            -1,  # window_size_right
            0.0,  # softmax_softcap
            False,  # deterministic
            None,  # rng_state (optional)
        )
    except RuntimeError as e:
        raise ImportError(
            "Flash Attention V2 is not supported on this machine. " f"Error: {e}"
        ) from e

    V2 = True
except ImportError:
    try:
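The probe above converts a runtime kernel failure into an ImportError, so the pre-existing except ImportError: branch falls back exactly as it would for a missing package. Below is a minimal, self-contained sketch of the same capability-probe pattern, written against the public flash_attn_func wrapper rather than the raw flash_attn_2_cuda binding; the helper name probe_flash_attn_v2 and the module-level flag are illustrative assumptions, not part of this commit.

# Sketch of the capability-probe pattern: attempt a tiny dummy forward at
# import time and treat any runtime kernel failure as "not supported".
# Names below (probe_flash_attn_v2, HAS_FLASH_ATTN_V2) are illustrative.
import torch


def probe_flash_attn_v2() -> bool:
    """Return True only if Flash Attention v2 both imports and runs on this GPU."""
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        return False  # package not installed or built without v2 kernels

    if not torch.cuda.is_available():
        return False

    try:
        # Kernels built for an unsupported compute capability often import
        # cleanly but raise RuntimeError the first time they are launched.
        q = torch.zeros(1, 1, 1, 64, device="cuda", dtype=torch.float16)
        flash_attn_func(q, q, q)  # q, k, v: dummy tensors of shape (B, S, H, D)
    except RuntimeError:
        return False
    return True


HAS_FLASH_ATTN_V2 = probe_flash_attn_v2()

Either way, the design choice is the same: detect unsupported hardware eagerly with a tiny dummy forward at import time, instead of failing later on the first real request.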