Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00
remove unnecessary code

parent 1f37d57266
commit 52f593bba7
@@ -91,20 +91,6 @@ except ImportError as e:
     HAS_FLASH_ATTN = True
 
 
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    total_tokens, num_key_value_heads, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :].expand(
-        total_tokens, num_key_value_heads, n_rep, head_dim
-    )
-    return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)
-
-
 def attention(
     q,
     k,
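The removed repeat_kv helper is the usual grouped-query/multi-query attention (GQA/MQA) expansion: each key/value head is duplicated n_rep times along the head dimension so K and V match the query head count. Note that although the docstring (inherited from the 4D batched layout) mentions (batch, num_key_value_heads, seqlen, head_dim), the removed code operates on a 3D flattened-token layout, (total_tokens, num_key_value_heads, head_dim). A minimal self-contained sketch of the same logic, with an illustrative check of the torch.repeat_interleave equivalence the docstring claims (tensor sizes here are made up):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (total_tokens, num_key_value_heads, head_dim) to
    # (total_tokens, num_key_value_heads * n_rep, head_dim) by
    # repeating each KV head n_rep consecutive times.
    total_tokens, num_key_value_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :].expand(
        total_tokens, num_key_value_heads, n_rep, head_dim
    )
    return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)

kv = torch.randn(16, 4, 64)   # 16 tokens, 4 KV heads, head_dim 64 (illustrative sizes)
out = repeat_kv(kv, n_rep=8)  # 4 KV heads -> 32 attention heads
assert out.shape == (16, 32, 64)
assert torch.equal(out, torch.repeat_interleave(kv, repeats=8, dim=1))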
@@ -235,17 +221,6 @@ def attention(
             None,
         )
     elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
-        # NOTE: The Triton kernel silently outputs wrong results when using MQA/GQA and not
-        # repeating.
-        # TODO: just a sketch. Kind of need to abstract this `attention` function to enable some customization and pass those - let's sync with Nicolas for which implem he'd like
-
-        # num_heads = q.shape[1]
-        # num_kv_heads = k.shape[1]
-        # if num_kv_heads != num_heads:
-        #     # Interleave for MQA workaround.
-        #     k = repeat_kv(k, num_heads // num_kv_heads)
-        #     v = repeat_kv(v, num_heads // num_kv_heads)
-
         output, _ = triton_attention(
             q,
             k,
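For context on the removed NOTE: with repeat_kv gone, the sketched MQA/GQA workaround for the Triton path goes with it. Reconstructed below as a hypothetical guard, using repeat_kv from the sketch above; the helper name and signature are assumptions, not the repo's actual API, and the attention kernel is passed in as a callable rather than hard-wired to triton_attention:

def repeat_kv_workaround(q, k, v, attention_kernel):
    # Hypothetical helper: repeat K/V heads before calling a kernel that,
    # per the removed NOTE, silently returns wrong results when
    # num_kv_heads != num_heads (MQA/GQA) and the heads are not repeated.
    num_heads, num_kv_heads = q.shape[1], k.shape[1]
    if num_kv_heads != num_heads:
        k = repeat_kv(k, num_heads // num_kv_heads)
        v = repeat_kv(v, num_heads // num_kv_heads)
    return attention_kernel(q, k, v)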