Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
fuse reshapes

parent 8f28011e1e
commit a2f437a291
@@ -1,3 +1,5 @@
+import os
+
 import torch
 import torch.distributed
 
@@ -292,22 +294,16 @@ class FlashRWLargeAttention(torch.nn.Module):
         if layer_past_present_indices is None:
             # Copy to layer past
             layer_past[...] = kv
-            k, v = kv.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
+            kv = kv.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                k,
-                v,
+                kv[:, :, 0],
+                kv[:, :, 1],
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -325,22 +321,16 @@ class FlashRWLargeAttention(torch.nn.Module):
         else:
             # Add present to the layer_past tensor at the correct indices
             layer_past[layer_past_present_indices] = kv
-            k, v = layer_past.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
-                -1, self.num_groups * self.num_heads, self.head_size
-            )
+            kv = layer_past.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                k,
-                v,
+                kv[:, :, 0],
+                kv[:, :, 1],
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
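
Not part of the commit: a minimal standalone sketch checking that the fused unsqueeze/expand/reshape in the diff yields the same k and v tensors as the split-then-expand path it replaces. The tensor sizes are made up; only the shape convention [tokens, num_groups, 2, head_size] for the packed kv tensor is taken from the code above.

import torch

tokens, num_groups, num_heads, head_size = 5, 2, 4, 8
kv = torch.randn(tokens, num_groups, 2, head_size)

# Removed path: split kv into k / v, then expand and reshape each one separately.
k, v = kv.split(1, dim=2)
k_old = k.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
v_old = v.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)

# New path: one fused expand/reshape on the packed kv tensor, sliced afterwards.
kv_fused = (
    kv.unsqueeze(2)
    .expand(-1, num_groups, num_heads, 2, head_size)
    .reshape(-1, num_groups * num_heads, 2, head_size)
)

assert torch.equal(kv_fused[:, :, 0], k_old)
assert torch.equal(kv_fused[:, :, 1], v_old)

Since expand only creates a view, the reshape is where data actually gets materialized; fusing the two reshapes into one presumably trades two separate copies (for k and for v) for a single copy of the packed tensor, which matches the commit title "fuse reshapes".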