fuse reshapes

OlivierDehaene 2023-05-30 17:09:34 +02:00
parent 8f28011e1e
commit a2f437a291


@@ -1,3 +1,5 @@
+import os
 import torch
 import torch.distributed
@@ -292,22 +294,16 @@ class FlashRWLargeAttention(torch.nn.Module):
         if layer_past_present_indices is None:
             # Copy to layer past
             layer_past[...] = kv
-            k, v = kv.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
+            kv = kv.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)

             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                k,
-                v,
+                kv[:, :, 0],
+                kv[:, :, 1],
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -325,22 +321,16 @@ class FlashRWLargeAttention(torch.nn.Module):
         else:
             # Add present to the layer_past tensor at the correct indices
             layer_past[layer_past_present_indices] = kv
-            k, v = layer_past.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
+            kv = layer_past.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)

             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                k,
-                v,
+                kv[:, :, 0],
+                kv[:, :, 1],
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
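
The change replaces two separate expand/reshape copies (one for k, one for v) with a single fused expand/reshape over the packed kv tensor; k and v are then passed to flash attention as the slices kv[:, :, 0] and kv[:, :, 1]. Below is a minimal standalone sketch (not part of the commit) that checks the fused path against the old split-and-expand path on random data; the tensor sizes are hypothetical stand-ins for the model's real num_groups, num_heads, and head_size.

import torch

# Hypothetical sizes standing in for the model's real dimensions.
total_tokens, num_groups, num_heads, head_size = 6, 2, 4, 16
kv = torch.randn(total_tokens, num_groups, 2, head_size)

# Old path: split kv into k and v, then expand each to the query shape,
# materializing two copies via reshape on the non-contiguous expanded views.
k, v = kv.split(1, dim=2)
k_old = k.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
v_old = v.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)

# New path: one fused expand/reshape that keeps k and v packed along dim 2,
# so a single copy is materialized instead of two.
kv_fused = kv.unsqueeze(2).expand(
    -1, num_groups, num_heads, 2, head_size
).reshape(-1, num_groups * num_heads, 2, head_size)

# The slices passed to flash attention match the old k and v exactly.
assert torch.equal(kv_fused[:, :, 0], k_old)
assert torch.equal(kv_fused[:, :, 1], v_old)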