Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
fix normal att
This commit is contained in: parent 63a18c1414, commit 12ab24ae64
@@ -125,12 +125,12 @@ class FlashRWAttention(torch.nn.Module):
         # Split query from key_value
         query, kv = qkv.split(
-            [self.head_size * self.num_heads, 2 * self.head_size], dim=1
+            [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv], dim=1
         )

         # Prepare query and key_value for indexing
         query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, 1, self.head_size)
+        kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)

         # Inplace rotary
         self.rotary_emb(query, cos, sin)
@@ -141,7 +141,7 @@ class FlashRWAttention(torch.nn.Module):
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
-            kv = kv.expand(-1, 2, query.shape[1], self.head_size)
+            kv = kv.expand(-1, 2, self.num_heads, self.head_size)

             # output
             attn_output = torch.empty_like(query)
@@ -168,7 +168,7 @@ class FlashRWAttention(torch.nn.Module):
             # Add present to the layer_past tensor at the correct indices
             layer_past[layer_past_present_indices] = kv
             # Expand to query shape
-            kv = layer_past.expand(-1, 2, query.shape[1], self.head_size)
+            kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)

             # output
             attn_output = torch.empty_like(query)
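For context, the sketch below (not part of the commit; tensor sizes are illustrative assumptions) walks through the same split/view/expand shape logic that the patched lines implement. It uses num_heads_kv = 1, i.e. multi-query attention, which is the case where expanding the singleton key/value-head dimension up to num_heads is valid:

import torch

# Illustrative sizes only (assumptions, not values from any model config)
num_heads = 8       # query heads
num_heads_kv = 1    # shared key/value head (multi-query attention)
head_size = 64
num_tokens = 5

# Fused qkv projection output: num_heads query heads + 2 * num_heads_kv kv heads
qkv = torch.randn(num_tokens, (num_heads + 2 * num_heads_kv) * head_size)

# Split query from key_value, mirroring the patched split sizes
query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_heads_kv], dim=1
)

# Reshape for per-head indexing
query = query.view(-1, num_heads, head_size)   # (tokens, num_heads, head_size)
kv = kv.view(-1, 2, num_heads_kv, head_size)   # (tokens, 2, num_heads_kv, head_size)

# Broadcast the single kv head across all query heads, as the patched expand does
kv = kv.expand(-1, 2, num_heads, head_size)    # (tokens, 2, num_heads, head_size)

print(query.shape, kv.shape)

The expand is a broadcast over the size-1 kv-head dimension, so no data is copied; every query head reads the same key/value tensors.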
|