mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 12:32:10 +00:00
fix fa2 triton kernel not working with MQA/GQA
This commit is contained in:
parent
325f9774fe
commit
aef931ea5d
@ -85,6 +85,17 @@ except ImportError as e:
|
||||
logger.warning(f"Unable to use Flash Attention V2: {e}")
|
||||
HAS_FLASH_ATTN = True
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
total_tokens, num_key_value_heads, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :].expand(total_tokens, num_key_value_heads, n_rep, head_dim)
|
||||
return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)
|
||||
|
||||
|
||||
def attention(
|
||||
q,
|
||||
@ -194,6 +205,16 @@ def attention(
|
||||
None,
|
||||
)
|
||||
elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
|
||||
# NOTE: The Triton kernel silently outputs wrong results when using MQA/GQA and not
|
||||
# repeating.
|
||||
# TODO: just a sketch. Kind of need to abstract this `attention` function to enable some customization and pass those - let's sync with Nicolas for which implem he'd like
|
||||
num_heads = q.shape[1]
|
||||
num_kv_heads = k.shape[1]
|
||||
if num_kv_heads != num_heads:
|
||||
# Interleave for MQA workaround.
|
||||
k = repeat_kv(k, num_heads // num_kv_heads)
|
||||
v = repeat_kv(v, num_heads // num_kv_heads)
|
||||
|
||||
output, _ = triton_attention(
|
||||
q,
|
||||
k,
|
||||
|
Loading…
Reference in New Issue
Block a user