Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00
fix: adjust positional embeddings for multi-dimensional position ids
This commit is contained in:
parent e1114c2726
commit 279b114ab3
@@ -130,23 +130,18 @@ class Qwen2Attention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        _query = query.clone()
-        _cos = cos.clone()
-        _sin = sin.clone()
-
-        self.rotary_emb(_query, torch.select(kv, dim=1, index=0), cos, sin)
-
-        _cos = torch.cat((_cos, _cos), dim=-1)
-        _sin = torch.cat((_sin, _sin), dim=-1)
-        q_emb = (_query * _cos).reshape(2, 1, -1) + (
-            rotate_half(_query) * _sin
-        ).reshape(2, 1, -1)
-        k_emb = (torch.select(kv, dim=1, index=0) * _cos).reshape(2, 1, -1) + (
-            rotate_half(torch.select(kv, dim=1, index=0)) * _sin
-        ).reshape(2, 1, -1)
-
-        query = q_emb.reshape(-1, self.num_heads, self.head_size)
-        kv[:, 0] = k_emb.reshape(-1, self.num_key_value_heads, self.head_size)
+        if self.mrope_section is not None:
+            # if mrope_section is set, we need to split the cos and sin into 3 parts and concatenate them in a specific order
+            cos = torch.cat(
+                [m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
+                dim=-1,
+            )
+            sin = torch.cat(
+                [m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
+                dim=-1,
+            )
+
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
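For context, the new branch interleaves rotary frequencies from the three position axes (temporal, height, width) according to mrope_section. Below is a minimal standalone sketch of that split-and-reinterleave step. The shapes and the mrope_section values are illustrative assumptions, not TGI's exact layout; it only assumes cos holds one row of rotary frequencies per position axis:

import torch

mrope_section = [16, 24, 24]  # assumed split sizes; sum to the rotary dim
seq_len = 5
rot_dim = sum(mrope_section)  # 64

# one cos row per position axis: [3 axes, seq_len, rot_dim]
cos = torch.randn(3, seq_len, rot_dim)

# split along the rotary dim, then take chunk i from position axis i % 3,
# as in the diff: chunk 0 <- temporal, chunk 1 <- height, chunk 2 <- width
cos_mrope = torch.cat(
    [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
    dim=-1,
)
print(cos_mrope.shape)  # torch.Size([5, 64]): one merged row per token

The same transformation is applied to sin, after which the standard rotary kernel can be called unchanged, which is why the diff drops the hand-rolled rotate_half path.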
@@ -330,10 +325,13 @@ class Qwen2Model(torch.nn.Module):
     ) -> torch.Tensor:
         hidden_states = inputs_embeds

-        # TODO: ensure we are getting the correct positional embeddings
+        # flatten position ids from 2D to 1D
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids[0, 0, :], true_max_s, hidden_states.dtype
+            position_ids.flatten(), true_max_s, hidden_states.dtype
         )
+        # reshape cos and sin for the number of position ids present in the input
+        cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
+        sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)

         residual = None
         for i, layer in enumerate(self.layers):
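The flatten-then-reshape step above is what lets a 1D cos/sin lookup serve multi-dimensional position ids: all axes are looked up in one batch, then folded back so dim 0 is the position axis again. A minimal sketch, assuming position_ids has shape [3, seq_len] and standing in for rotary_emb.get_cos_sin with a dummy rotary lookup (both assumptions, not TGI's actual signatures):

import torch

n_axes, seq_len, rot_half = 3, 5, 32
position_ids = torch.arange(seq_len).repeat(n_axes, 1)  # [3, seq_len]

inv_freq = 1.0 / (10000 ** (torch.arange(0, 2 * rot_half, 2) / (2 * rot_half)))

def get_cos_sin(pos, dtype=torch.float32):
    # stand-in for rotary_emb.get_cos_sin: one (cos, sin) row per flat position
    freqs = torch.outer(pos.float(), inv_freq)
    return freqs.cos().to(dtype), freqs.sin().to(dtype)

# flatten the 2D position ids to 1D, look up embeddings for every id at once,
# then reshape so dim 0 is the position axis, with a broadcast dim for heads
cos, sin = get_cos_sin(position_ids.flatten())
cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
print(cos.shape)  # torch.Size([3, 5, 1, 32])

The old code's position_ids[0, 0, :] used only the first position axis, which is exactly the bug the commit title describes.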