fix normal att

This commit is contained in:
OlivierDehaene 2023-05-29 12:10:17 +02:00
parent 63a18c1414
commit 12ab24ae64

View File

@@ -125,12 +125,12 @@ class FlashRWAttention(torch.nn.Module):
# Split query from key_value # Split query from key_value
query, kv = qkv.split( query, kv = qkv.split(
[self.head_size * self.num_heads, 2 * self.head_size], dim=1 [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv], dim=1
) )
# Prepare query and key_value for indexing # Prepare query and key_value for indexing
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, 1, self.head_size) kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)
# Inplace rotary # Inplace rotary
self.rotary_emb(query, cos, sin) self.rotary_emb(query, cos, sin)
@@ -141,7 +141,7 @@ class FlashRWAttention(torch.nn.Module):
# Copy to layer past # Copy to layer past
layer_past[...] = kv layer_past[...] = kv
# Expand to query shape # Expand to query shape
kv = kv.expand(-1, 2, query.shape[1], self.head_size) kv = kv.expand(-1, 2, self.num_heads, self.head_size)
# output # output
attn_output = torch.empty_like(query) attn_output = torch.empty_like(query)
@@ -168,7 +168,7 @@ class FlashRWAttention(torch.nn.Module):
# Add present to the layer_past tensor at the correct indices # Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv layer_past[layer_past_present_indices] = kv
# Expand to query shape # Expand to query shape
kv = layer_past.expand(-1, 2, query.shape[1], self.head_size) kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output # output
attn_output = torch.empty_like(query) attn_output = torch.empty_like(query)