mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 12:54:52 +00:00)
fix: attempt forward on flash attn2 to check hardware support
parent 53aec27328
commit 4b1005c7e1
@@ -173,6 +173,41 @@ def paged_attention(
 try:
     import flash_attn_2_cuda
+
+    # try forwarding to see if it works with all dummy inputs
+    batch_size = 1
+    num_heads = 1
+    head_dim = 1
+    seqlen = 1
+
+    try:
+        flash_attn_2_cuda.varlen_fwd(
+            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # q
+            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # k
+            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # v
+            None,  # out (optional)
+            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_q
+            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_k
+            None,  # alibi_slopes (optional)
+            None,  # q_padded (optional)
+            None,  # k_padded (optional)
+            None,  # v_padded (optional)
+            seqlen,  # max_seqlen_q
+            seqlen,  # max_seqlen_k
+            1.0,  # softmax_scale
+            0.0,  # softmax_lse (default value)
+            False,  # is_causal
+            True,  # return_softmax
+            -1,  # window_size_left
+            -1,  # window_size_right
+            0.0,  # softmax_softcap
+            False,  # deterministic
+            None,  # rng_state (optional)
+        )
+    except RuntimeError as e:
+        raise ImportError(
+            "Flash Attention V2 is not supported on this machine. " f"Error: {e}"
+        ) from e
+
     V2 = True
 except ImportError:
     try:
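The idea behind the change is small enough to sketch in isolation: importing a compiled extension can succeed on a machine whose GPU the kernels were not built for, so the only reliable check is to run one dummy forward and convert a runtime failure into an ImportError, which the surrounding `except ImportError:` fallback already handles. The sketch below is illustrative only; the helper name `ensure_kernel_runs` and the toy `fake_varlen_fwd` are assumptions for demonstration, not TGI code, whereas the actual change inlines the probe with a direct call to `flash_attn_2_cuda.varlen_fwd` as shown in the diff above.

    # Illustrative sketch of the probe-by-forward pattern (not TGI's actual code).
    def ensure_kernel_runs(forward, *dummy_args):
        """Call `forward` once with dummy inputs; surface runtime failures as
        ImportError so an existing `except ImportError:` fallback still catches them."""
        try:
            forward(*dummy_args)
        except RuntimeError as e:
            raise ImportError(
                f"Kernel imported but is not usable on this machine. Error: {e}"
            ) from e


    if __name__ == "__main__":
        # Toy stand-in for a compiled attention kernel built for a newer GPU.
        def fake_varlen_fwd(*args):
            raise RuntimeError("FlashAttention only supports Ampere GPUs or newer.")

        try:
            ensure_kernel_runs(fake_varlen_fwd, None)
        except ImportError as e:
            print(f"Falling back to Flash Attention V1: {e}")

Re-raising as ImportError (rather than letting the RuntimeError escape) is what lets the existing `except ImportError:` branch in the diff fall through to the Flash Attention V1 path without any changes to the fallback logic.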