fix fa2 triton kernel not working with MQA/GQA

fxmarty 2024-04-20 21:16:11 +00:00
parent 325f9774fe
commit aef931ea5d


@@ -85,6 +85,17 @@ except ImportError as e:
    logger.warning(f"Unable to use Flash Attention V2: {e}")
    HAS_FLASH_ATTN = True
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (total_tokens, num_key_value_heads, head_dim) to (total_tokens, num_attention_heads, head_dim).
    """
    total_tokens, num_key_value_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :].expand(total_tokens, num_key_value_heads, n_rep, head_dim)
    return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)
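
# Shape example (illustrative values, not taken from the file): if k has shape
# (total_tokens=8, num_key_value_heads=2, head_dim=64) and the query uses 8 attention heads,
# then n_rep = 8 // 2 = 4 and repeat_kv(k, 4) returns a tensor of shape (8, 8, 64).
# Copies of the same KV head end up adjacent (repeat_interleave semantics), so pairing query
# head i with repeated KV head i matches the usual GQA grouping of head i // n_rep.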
def attention(
    q,
@@ -194,6 +205,16 @@ def attention(
            None,
        )
    elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
        # NOTE: The Triton kernel silently outputs wrong results when using MQA/GQA
        # without repeating the KV heads.
        # TODO: This is just a sketch; we should abstract this `attention` function to
        # allow some customization and pass these options in. Let's sync with Nicolas
        # on which implementation he'd prefer.
        num_heads = q.shape[1]
        num_kv_heads = k.shape[1]
        if num_kv_heads != num_heads:
            # Repeat (interleave) the KV heads as an MQA/GQA workaround.
            k = repeat_kv(k, num_heads // num_kv_heads)
            v = repeat_kv(v, num_heads // num_kv_heads)
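        # At this point k and v have num_heads KV heads (either natively or after the repeat
        # above), so the Triton kernel receives plain multi-head inputs. The repeat assumes
        # num_heads is a multiple of num_kv_heads and materializes full copies of k and v,
        # trading extra memory for correctness on MQA/GQA models.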
        output, _ = triton_attention(
            q,
            k,