mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
black
This commit is contained in:
parent
a2f437a291
commit
c7b899a438
@ -295,7 +295,11 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
# Copy to layer past
|
# Copy to layer past
|
||||||
layer_past[...] = kv
|
layer_past[...] = kv
|
||||||
# Expand to query shape
|
# Expand to query shape
|
||||||
kv = kv.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
|
kv = (
|
||||||
|
kv.unsqueeze(2)
|
||||||
|
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
|
||||||
|
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
|
||||||
|
)
|
||||||
|
|
||||||
# output
|
# output
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
@ -322,7 +326,11 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
# Add present to the layer_past tensor at the correct indices
|
# Add present to the layer_past tensor at the correct indices
|
||||||
layer_past[layer_past_present_indices] = kv
|
layer_past[layer_past_present_indices] = kv
|
||||||
# Expand to query shape
|
# Expand to query shape
|
||||||
kv = layer_past.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
|
kv = (
|
||||||
|
layer_past.unsqueeze(2)
|
||||||
|
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
|
||||||
|
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
|
||||||
|
)
|
||||||
|
|
||||||
# output
|
# output
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
|
Loading…
Reference in New Issue
Block a user