mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
fuse reshapes
This commit is contained in:
parent
8f28011e1e
commit
a2f437a291
@ -1,3 +1,5 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
@ -292,22 +294,16 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
if layer_past_present_indices is None:
|
if layer_past_present_indices is None:
|
||||||
# Copy to layer past
|
# Copy to layer past
|
||||||
layer_past[...] = kv
|
layer_past[...] = kv
|
||||||
k, v = kv.split(1, dim=2)
|
|
||||||
# Expand to query shape
|
# Expand to query shape
|
||||||
k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
|
kv = kv.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
|
||||||
-1, self.num_groups * self.num_heads, self.head_size
|
|
||||||
)
|
|
||||||
v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
|
|
||||||
-1, self.num_groups * self.num_heads, self.head_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# output
|
# output
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn_cuda.fwd(
|
flash_attn_cuda.fwd(
|
||||||
query,
|
query,
|
||||||
k,
|
kv[:, :, 0],
|
||||||
v,
|
kv[:, :, 1],
|
||||||
attn_output,
|
attn_output,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
@ -325,22 +321,16 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
# Add present to the layer_past tensor at the correct indices
|
# Add present to the layer_past tensor at the correct indices
|
||||||
layer_past[layer_past_present_indices] = kv
|
layer_past[layer_past_present_indices] = kv
|
||||||
k, v = layer_past.split(1, dim=2)
|
|
||||||
# Expand to query shape
|
# Expand to query shape
|
||||||
k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
|
kv = layer_past.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
|
||||||
-1, self.num_groups * self.num_heads, self.head_size
|
|
||||||
)
|
|
||||||
v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
|
|
||||||
-1, self.num_groups * self.num_heads, self.head_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# output
|
# output
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn_cuda.fwd(
|
flash_attn_cuda.fwd(
|
||||||
query,
|
query,
|
||||||
k,
|
kv[:, :, 0],
|
||||||
v,
|
kv[:, :, 1],
|
||||||
attn_output,
|
attn_output,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
|
Loading…
Reference in New Issue
Block a user