From 4b1005c7e16d29ece8cde637b6ab516595eab067 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 30 Jul 2024 17:20:40 +0000
Subject: [PATCH] fix: attempt forward on flash attn2 to check hardware support

---
 .../layers/attention/cuda.py | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index c0c4da4d..19ce294b 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -173,6 +173,41 @@ def paged_attention(
 try:
     import flash_attn_2_cuda
 
+    # Run a tiny dummy forward pass to verify the kernel actually works on this hardware.
+    batch_size = 1
+    num_heads = 1
+    head_dim = 32
+    seqlen = 1
+
+    try:
+        flash_attn_2_cuda.varlen_fwd(
+            torch.zeros(batch_size * seqlen, num_heads, head_dim, dtype=torch.float16, device="cuda"),  # q (packed)
+            torch.zeros(batch_size * seqlen, num_heads, head_dim, dtype=torch.float16, device="cuda"),  # k (packed)
+            torch.zeros(batch_size * seqlen, num_heads, head_dim, dtype=torch.float16, device="cuda"),  # v (packed)
+            None,  # out (optional)
+            torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * seqlen,  # cu_seqlens_q
+            torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * seqlen,  # cu_seqlens_k
+            None,  # seqused_k (optional)
+            None,  # leftpad_k (optional)
+            None,  # block_table (optional)
+            None,  # alibi_slopes (optional)
+            seqlen,  # max_seqlen_q
+            seqlen,  # max_seqlen_k
+            0.0,  # dropout_p
+            1.0,  # softmax_scale
+            False,  # zero_tensors
+            False,  # is_causal
+            -1,  # window_size_left
+            -1,  # window_size_right
+            0.0,  # softcap
+            False,  # return_softmax
+            None,  # generator (optional)
+        )
+    except RuntimeError as e:
+        raise ImportError(
+            f"Flash Attention V2 is not supported on this machine. Error: {e}"
+        ) from e
+
     V2 = True
 except ImportError:
     try:
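
Note: for reference, a minimal standalone sketch of the same probing idea, written against the public flash_attn Python API (flash_attn_varlen_func) rather than the raw flash_attn_2_cuda binding the patch calls. It assumes torch and flash-attn are installed and a CUDA device is visible; the helper name has_flash_attn_v2 is hypothetical and not part of the patch.

    import torch


    def has_flash_attn_v2() -> bool:
        """Return True if a tiny FlashAttention-2 varlen forward runs on this GPU."""
        try:
            from flash_attn import flash_attn_varlen_func
        except ImportError:
            return False
        try:
            # Tiny fp16 inputs on the GPU in the packed varlen layout (tokens, heads, head_dim).
            q = torch.zeros(1, 1, 32, dtype=torch.float16, device="cuda")
            cu_seqlens = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
            flash_attn_varlen_func(q, q, q, cu_seqlens, cu_seqlens, 1, 1, softmax_scale=1.0)
            return True
        except RuntimeError:
            # e.g. "FlashAttention only supports Ampere GPUs or newer"
            return False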