fuse reshapes

OlivierDehaene 2023-05-30 17:09:34 +02:00
parent 8f28011e1e
commit a2f437a291


@@ -1,3 +1,5 @@
import os
import torch
import torch.distributed
@@ -292,22 +294,16 @@ class FlashRWLargeAttention(torch.nn.Module):
         if layer_past_present_indices is None:
             # Copy to layer past
             layer_past[...] = kv
-            k, v = kv.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
+            kv = kv.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)

             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                k,
-                v,
+                kv[:, :, 0],
+                kv[:, :, 1],
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -325,22 +321,16 @@ class FlashRWLargeAttention(torch.nn.Module):
         else:
             # Add present to the layer_past tensor at the correct indices
             layer_past[layer_past_present_indices] = kv
-            k, v = layer_past.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
+            kv = layer_past.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)

             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                k,
-                v,
+                kv[:, :, 0],
+                kv[:, :, 1],
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
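
Below is a minimal standalone sketch of why the fused reshape is a drop-in replacement. It is not part of the commit, and the sizes are arbitrary toy values: expanding the stacked kv tensor once and then indexing its k/v dimension yields exactly the same tensors as splitting kv and expanding k and v separately.

import torch

# Toy sizes, hypothetical values for illustration only
seq, num_groups, num_heads, head_size = 4, 2, 8, 16

kv = torch.randn(seq, num_groups, 2, head_size)

# Old path: split into k and v, then expand/reshape each one
k, v = kv.split(1, dim=2)
k = k.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
v = v.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)

# New path: one expand/reshape on the stacked kv, then index the k/v dimension
fused = (
    kv.unsqueeze(2)
    .expand(-1, num_groups, num_heads, 2, head_size)
    .reshape(-1, num_groups * num_heads, 2, head_size)
)

# The fused views match the separately expanded tensors
assert torch.equal(fused[:, :, 0], k)
assert torch.equal(fused[:, :, 1], v)

Because reshaping an expanded tensor materializes a copy, the fused path performs a single contiguous copy for the packed kv instead of two separate copies for k and v, which is presumably the motivation for the change.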