Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
use native grouped attention
commit f400f2d58b
parent 2d4b31070e
@@ -182,10 +182,6 @@ class FlashRWAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            if self.num_heads_kv == 1:
-                # Expand to query shape
-                kv = kv.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
             flash_attn_2_cuda.varlen_fwd(
                 query,
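The lines removed above expanded the single K/V head to the full query shape before the kernel call; flash-attention 2 accepts K/V with fewer heads than the query and broadcasts them internally, which is what "native grouped attention" refers to. A minimal pure-PyTorch sketch of that equivalence, with toy shapes and helper names chosen here for illustration (not the TGI code):

import torch

# Toy shapes mirroring the diff: kv packs K and V as
# [total_tokens, 2, num_heads_kv, head_size], query as [total_tokens, num_heads, head_size].
total_tokens, num_heads, head_size = 16, 8, 64
query = torch.randn(total_tokens, num_heads, head_size)
kv = torch.randn(total_tokens, 2, 1, head_size)  # num_heads_kv == 1 (multi-query)

def attention(q, k, v):
    # Plain softmax attention over the token dimension, per head.
    scores = torch.einsum("qhd,khd->hqk", q, k) / head_size**0.5
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)

# What the deleted code did: expand the single K/V head to the query shape.
kv_expanded = kv.expand(-1, 2, num_heads, head_size)
out_expanded = attention(query, kv_expanded[:, 0], kv_expanded[:, 1])

# What a kernel with native MQA/GQA support does: map each query head to its
# K/V head (head // group_size; with a single K/V head every head maps to 0).
group_size = num_heads // kv.shape[2]
head_map = torch.arange(num_heads) // group_size
out_native = attention(query, kv[:, 0, head_map], kv[:, 1, head_map])

assert torch.allclose(out_expanded, out_native)

Since expand only creates a broadcast view, the gain is mainly that the kernel now treats the K/V heads as shared instead of nominally distinct ones.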
@@ -313,13 +309,6 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # Expand to query shape
-            kv = (
-                kv.unsqueeze(2)
-                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
-                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
-            )
-
             # flash attention
             flash_attn_2_cuda.varlen_fwd(
                 query,
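For the grouped variant, each of num_groups K/V heads serves num_heads query heads (num_heads here counting heads per group, as in the expand being deleted), and the removed unsqueeze/expand/reshape laid those repeats out consecutively along the head axis. A short sketch, again with hypothetical toy dimensions, showing that this layout is exactly a repeat_interleave of the grouped K/V, i.e. the per-group broadcast a GQA-aware kernel performs on its own:

import torch

# Toy dimensions: num_groups K/V heads, heads_per_group query heads per group.
total_tokens, num_groups, heads_per_group, head_size = 16, 4, 2, 64
kv = torch.randn(total_tokens, num_groups, 2, head_size)  # [tokens, groups, k/v, dim]

# The deleted expand/reshape: repeat every group for each of its query heads.
kv_materialized = (
    kv.unsqueeze(2)
    .expand(-1, num_groups, heads_per_group, 2, head_size)
    .reshape(-1, num_groups * heads_per_group, 2, head_size)
)

# The same layout expressed directly: repeat_interleave along the head axis,
# which is the per-group broadcast a kernel with native grouped attention applies.
assert torch.equal(kv_materialized, torch.repeat_interleave(kv, heads_per_group, dim=1))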
@@ -271,9 +271,6 @@ class FlashMQAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # Expand from 1 to num_heads
-            key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
             flash_attn_2_cuda.varlen_fwd(
                 query,
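With the expands gone, all three modules pass K/V with only num_heads_kv (or num_groups) heads straight to the kernel. The diff calls the low-level flash_attn_2_cuda.varlen_fwd binding, whose remaining arguments are cut off here; as a sketch under the assumption of flash-attn >= 2.x, the same native MQA/GQA behaviour is available through the library's public flash_attn_func, which takes K/V with fewer heads than the query as long as the query head count is a multiple of the K/V head count:

import torch
from flash_attn import flash_attn_func  # assumes flash-attn >= 2.x, a CUDA device, fp16/bf16 inputs

# Hypothetical shapes: [batch, seqlen, num_heads, head_size]; K/V carry fewer heads.
q = torch.randn(1, 128, 32, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)  # 8 K/V heads for 32 query heads
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)

# No expand of k/v to 32 heads: the kernel maps the 32 query heads onto the
# 8 K/V heads internally (grouped-query attention).
out = flash_attn_func(q, k, v, causal=True)  # -> [1, 128, 32, 64]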