use native grouped attention

OlivierDehaene 2023-07-18 09:21:22 +02:00
parent 2d4b31070e
commit f400f2d58b
2 changed files with 0 additions and 14 deletions
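
All 14 deleted lines did the same job: they broadcast the key/value tensor up to the full query-head count before handing it to the flash attention kernel. FlashAttention 2 handles multi-query and grouped-query attention natively when K/V simply carry fewer heads than the query, so the manual expansion can go. Below is a minimal PyTorch sketch of why the two formulations agree, using a plain reference attention rather than the flash_attn_2_cuda kernel; every size in it is illustrative, not taken from the models.

import torch

torch.manual_seed(0)

seq_len, num_heads, num_kv_heads, head_size = 5, 8, 2, 16
group = num_heads // num_kv_heads

q = torch.randn(num_heads, seq_len, head_size)
k = torch.randn(num_kv_heads, seq_len, head_size)
v = torch.randn(num_kv_heads, seq_len, head_size)

def attention(q, k, v):
    # Plain softmax(Q K^T / sqrt(d)) V, batched over the leading head dims.
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    return scores.softmax(dim=-1) @ v

# What the deleted code produced: replicate each KV head so K/V match the
# query head count, then run ordinary multi-head attention.
out_expanded = attention(
    q,
    k.repeat_interleave(group, dim=0),
    v.repeat_interleave(group, dim=0),
)

# Native grouped attention: keep K/V at num_kv_heads and let broadcasting
# pair each group of query heads with its single KV head.
out_grouped = attention(
    q.view(num_kv_heads, group, seq_len, head_size),
    k.unsqueeze(1),
    v.unsqueeze(1),
).reshape(num_heads, seq_len, head_size)

torch.testing.assert_close(out_expanded, out_grouped)

The first path mirrors what the deleted expand calls built explicitly; the second is, conceptually, what a GQA-aware kernel computes without ever materializing the extra heads.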

@@ -182,10 +182,6 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
-            if self.num_heads_kv == 1:
-                # Expand to query shape
-                kv = kv.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
             flash_attn_2_cuda.varlen_fwd(
                 query,
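
For context, the branch deleted from FlashRWAttention only ran in the multi-query case: with num_heads_kv == 1, kv has shape (total_tokens, 2, 1, head_size) and the expand produced a broadcast view at the full head count rather than a copy. A small sketch with made-up sizes:

import torch

total_tokens, num_heads, head_size = 10, 8, 64  # illustrative sizes

# Packed K/V with a single shared head: (total_tokens, 2, 1, head_size)
kv = torch.randn(total_tokens, 2, 1, head_size)

# The deleted line: a broadcast view over the head dimension, not a copy.
expanded = kv.expand(-1, 2, num_heads, head_size)

print(expanded.shape)        # torch.Size([10, 2, 8, 64])
print(expanded.stride())     # (128, 64, 0, 1): stride 0 on the head dim
print(expanded.data_ptr() == kv.data_ptr())  # True: same storage

So this first call site only lost a cheap view; it simply became unnecessary once the kernel accepts a single KV head directly.
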
@@ -313,13 +309,6 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
-            # Expand to query shape
-            kv = (
-                kv.unsqueeze(2)
-                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
-                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
-            )
-
             # flash attention
             flash_attn_2_cuda.varlen_fwd(
                 query,
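
The FlashRWLargeAttention case (the grouped variant, where self.num_heads is the number of query heads per KV group) was the expensive one: the expand is still a free view, but the following reshape cannot merge a broadcast dimension into a real one as a view, so it materialized a copy of K/V at the full query-head count for every prefill token. A sketch of the shape walk-through, again with illustrative sizes and heads_per_group standing in for self.num_heads:

import torch

total_tokens, num_groups, heads_per_group, head_size = 10, 8, 16, 64  # illustrative

# One K/V pair per group: (total_tokens, num_groups, 2, head_size)
kv = torch.randn(total_tokens, num_groups, 2, head_size)

# The deleted chain, step by step:
expanded = kv.unsqueeze(2).expand(-1, num_groups, heads_per_group, 2, head_size)
replicated = expanded.reshape(-1, num_groups * heads_per_group, 2, head_size)

print(expanded.data_ptr() == kv.data_ptr())    # True: expand is still a view
print(replicated.data_ptr() == kv.data_ptr())  # False: reshape had to copy
print(replicated.shape)                        # torch.Size([10, 128, 2, 64])

Dropping the chain therefore avoids that temporary per-prefill allocation entirely, not just a few lines of code.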

@@ -271,9 +271,6 @@ class FlashMQAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
-            # Expand from 1 to num_heads
-            key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
             flash_attn_2_cuda.varlen_fwd(
                 query,
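
After this change all three call sites pass K/V at their true head count and rely on the kernel's grouped handling; FlashAttention 2 supports that as long as the number of query heads is a multiple of the number of KV heads. Below is a hypothetical guard one could keep near such call sites to fail fast on a mismatched configuration; the function name and the packed (total_tokens, 2, num_kv_heads, head_size) KV layout it assumes (the layout seen in the first and third hunks) are illustrative, not code from this repository.

import torch

def check_gqa_shapes(query: torch.Tensor, kv: torch.Tensor, head_size: int) -> None:
    """Sanity-check packed shapes before a grouped-attention kernel call.

    Assumed layouts (hypothetical):
      query: (total_tokens, num_query_heads, head_size)
      kv:    (total_tokens, 2, num_kv_heads, head_size)  # K and V stacked on dim 1
    """
    total_tokens, num_query_heads, q_head_size = query.shape
    assert q_head_size == head_size, "query head size mismatch"
    assert kv.dim() == 4 and kv.shape[0] == total_tokens, "kv token count mismatch"
    assert kv.shape[1] == 2 and kv.shape[3] == head_size, "kv must stack K and V on dim 1"
    num_kv_heads = kv.shape[2]
    # Grouped/multi-query attention: each KV head serves an equal-sized group
    # of query heads, so the query head count must divide evenly.
    assert num_query_heads % num_kv_heads == 0, (
        f"cannot group {num_query_heads} query heads over {num_kv_heads} KV heads"
    )

Calling something like check_gqa_shapes(query, kv, self.head_size) right before varlen_fwd would surface a bad head configuration as a clear assertion instead of a kernel error; it is a sketch, not part of the commit.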