commit c7b899a438
parent a2f437a291
Author: OlivierDehaene
Date:   2023-05-30 17:09:51 +02:00


@@ -295,7 +295,11 @@ class FlashRWLargeAttention(torch.nn.Module):
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
-            kv = kv.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
+            kv = (
+                kv.unsqueeze(2)
+                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
+                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
+            )
 
             # output
             attn_output = torch.empty_like(query)
@@ -322,7 +326,11 @@ class FlashRWLargeAttention(torch.nn.Module):
             # Add present to the layer_past tensor at the correct indices
             layer_past[layer_past_present_indices] = kv
             # Expand to query shape
-            kv = layer_past.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
+            kv = (
+                layer_past.unsqueeze(2)
+                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
+                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
+            )
 
             # output
             attn_output = torch.empty_like(query)
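
For reference, both hunks only reflow the same expression; behavior is unchanged. A grouped KV cache entry of shape [seqlen, num_groups, 2, head_size] is broadcast across each group's query heads and flattened to [seqlen, num_groups * num_heads, 2, head_size]. Below is a minimal standalone sketch of that unsqueeze/expand/reshape pattern; the concrete sizes (seqlen, num_groups, num_heads, head_size) are made-up illustration values, not taken from the model config:

    import torch

    # Illustrative sizes only; the real values come from the model config
    seqlen, num_groups, num_heads, head_size = 4, 2, 8, 64

    # One key/value pair per KV group: [seqlen, num_groups, 2, head_size]
    kv = torch.randn(seqlen, num_groups, 2, head_size)

    # Broadcast each group's KV across that group's query heads, then
    # flatten (groups, heads) into a single head dimension, as in the diff
    kv = (
        kv.unsqueeze(2)  # [seqlen, num_groups, 1, 2, head_size]
        .expand(-1, num_groups, num_heads, 2, head_size)
        .reshape(-1, num_groups * num_heads, 2, head_size)
    )

    assert kv.shape == (seqlen, num_groups * num_heads, 2, head_size)

Note that expand adds no memory (it creates a stride-0 view); the subsequent reshape materializes the copy so the flattened KV lines up head-for-head with the query tensor.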