diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index e400fe47..32adc6de 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -91,20 +91,6 @@ except ImportError as e:
     HAS_FLASH_ATTN = True
 
 
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    total_tokens, num_key_value_heads, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :].expand(
-        total_tokens, num_key_value_heads, n_rep, head_dim
-    )
-    return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)
-
-
 def attention(
     q,
     k,
@@ -235,17 +221,6 @@ def attention(
             None,
         )
     elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
-        # NOTE: The Triton kernel silently outputs wrong results when using MQA/GQA and not
-        # repeating.
-        # TODO: just a sketch. Kind of need to abstract this `attention` function to enable some customization and pass those - let's sync with Nicolas for which implem he'd like
-
-        # num_heads = q.shape[1]
-        # num_kv_heads = k.shape[1]
-        # if num_kv_heads != num_heads:
-        #     # Interleave for MQA workaround.
-        #     k = repeat_kv(k, num_heads // num_kv_heads)
-        #     v = repeat_kv(v, num_heads // num_kv_heads)
-
         output, _ = triton_attention(
             q,
             k,
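
For reference (not part of the patch): a minimal standalone sketch of what the removed `repeat_kv` helper did, assuming the (total_tokens, num_heads, head_dim) layout used in this file. It repeats each KV head so that GQA/MQA key/value tensors match the number of query heads, which is equivalent to torch.repeat_interleave along the head dimension. The helper name and shapes below are illustrative only.

import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (total_tokens, num_kv_heads, head_dim) to
    # (total_tokens, num_kv_heads * n_rep, head_dim) by repeating each KV head n_rep times.
    total_tokens, num_kv_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :].expand(total_tokens, num_kv_heads, n_rep, head_dim)
    return expanded.reshape(total_tokens, num_kv_heads * n_rep, head_dim)

# Illustrative check: 2 KV heads expanded to match 8 query heads (n_rep = 4).
k = torch.randn(16, 2, 64)
assert torch.equal(repeat_kv_sketch(k, 4), torch.repeat_interleave(k, repeats=4, dim=1))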