mirror of
https://github.com/huggingface/text-generation-inference.git
feat: avoid copy for partial rotary embeddings
This commit is contained in:
parent c49332adb6
commit 18f13a1b5f
@@ -180,16 +180,11 @@ class FlashPhiAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        # Apply partial rotary embedding and store the end of the embedding
-        query_pass = query[:, :, self.rotary_emb_dim:]
-        key_pass = torch.select(kv, dim=1, index=0)[:, :, self.rotary_emb_dim:]
-
-        # Apply in place positional rotary embeddings
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-
-        # Restore the query and key from the partial rotary embedding
-        kv[:, 0, :, self.rotary_emb_dim:] = key_pass
-        query[:, :, self.rotary_emb_dim:] = query_pass
+        # Apply partial positional embeddings in place
+        self.rotary_emb(
+            query[:, :, :self.rotary_emb_dim], kv[:, 0, :, :self.rotary_emb_dim],
+            cos, sin
+        )

         # Reshape key and value and cache
         paged_attention.reshape_and_cache(
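The change works because basic slicing in PyTorch returns a view that shares storage with the parent tensor, so an in-place rotary update applied to query[:, :, :self.rotary_emb_dim] rotates those channels directly inside the full query tensor and the pass-through tail no longer needs to be saved and restored. A minimal sketch of that property, assuming the rotary embedding mutates its inputs in place as the old comment states (the shapes and the mul_ stand-in below are illustrative, not the actual rotary kernel):

import torch

rotary_dim = 32
query = torch.randn(4, 8, 64)             # (tokens, heads, head_size), illustrative shape
rotary_part = query[:, :, :rotary_dim]    # basic slice -> a view, not a copy

# The view shares storage with the full tensor.
assert rotary_part.data_ptr() == query.data_ptr()

tail_before = query[:, :, rotary_dim:].clone()
rotary_part.mul_(0.5)                     # stand-in for an in-place rotary update

# The rotated channels changed inside `query` while the tail is untouched,
# so no query_pass/key_pass copy-and-restore step is needed.
assert torch.equal(query[:, :, rotary_dim:], tail_before)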