From aef931ea5d10b2021d6a52daa09369842491ed30 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Sat, 20 Apr 2024 21:16:11 +0000
Subject: [PATCH] fix fa2 triton kernel not working with MQA/GQA

---
 .../utils/flash_attn.py | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 791d705c..45245357 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -85,6 +85,17 @@ except ImportError as e:
     logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    Equivalent of torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep). The hidden states go from
+    (total_tokens, num_key_value_heads, head_dim) to (total_tokens, num_attention_heads, head_dim).
+    """
+    total_tokens, num_key_value_heads, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :].expand(total_tokens, num_key_value_heads, n_rep, head_dim)
+    return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)
+
 
 def attention(
     q,
@@ -194,6 +205,16 @@ def attention(
             None,
         )
     elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
+        # NOTE: the Triton kernel silently returns wrong results when the number of KV heads
+        # differs from the number of query heads (MQA/GQA) and the KV heads are not repeated.
+        # TODO: this is just a sketch; the `attention` function should be abstracted so this kind
+        # of customization can be passed in - sync with Nicolas on the preferred implementation.
+        num_heads = q.shape[1]
+        num_kv_heads = k.shape[1]
+        if num_kv_heads != num_heads:
+            # Interleave the KV heads as a workaround for MQA/GQA.
+            k = repeat_kv(k, num_heads // num_kv_heads)
+            v = repeat_kv(v, num_heads // num_kv_heads)
         output, _ = triton_attention(
             q,
             k,
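
As a quick sanity check of the repeated-KV logic, here is a standalone sketch (not part of the patch; the shapes of 5 tokens, 8 query heads, 2 KV heads and head_dim 64 are illustrative assumptions) showing that the helper matches torch.repeat_interleave along the head dimension:

# Standalone sketch: verify the repeat_kv helper against torch.repeat_interleave.
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the helper added in the patch, operating on (total_tokens, num_kv_heads, head_dim).
    total_tokens, num_key_value_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :].expand(total_tokens, num_key_value_heads, n_rep, head_dim)
    return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)

k = torch.randn(5, 2, 64)   # 5 tokens, 2 KV heads (GQA), head_dim 64
out = repeat_kv(k, 8 // 2)  # repeat so that each of the 8 query heads has a matching KV head
assert out.shape == (5, 8, 64)
assert torch.equal(out, torch.repeat_interleave(k, repeats=4, dim=1))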