mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: adjust positional embeddings for multi dimensional position ids
This commit is contained in:
parent
e1114c2726
commit
279b114ab3
@ -130,23 +130,18 @@ class Qwen2Attention(torch.nn.Module):
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
_query = query.clone()
|
||||
_cos = cos.clone()
|
||||
_sin = sin.clone()
|
||||
if self.mrope_section is not None:
|
||||
# if mrope_section is set, we need to split the cos and sin into 3 parts and concatenate them in a specific order
|
||||
cos = torch.cat(
|
||||
[m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
|
||||
dim=-1,
|
||||
)
|
||||
sin = torch.cat(
|
||||
[m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
self.rotary_emb(_query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
_cos = torch.cat((_cos, _cos), dim=-1)
|
||||
_sin = torch.cat((_sin, _sin), dim=-1)
|
||||
q_emb = (_query * _cos).reshape(2, 1, -1) + (
|
||||
rotate_half(_query) * _sin
|
||||
).reshape(2, 1, -1)
|
||||
k_emb = (torch.select(kv, dim=1, index=0) * _cos).reshape(2, 1, -1) + (
|
||||
rotate_half(torch.select(kv, dim=1, index=0)) * _sin
|
||||
).reshape(2, 1, -1)
|
||||
|
||||
query = q_emb.reshape(-1, self.num_heads, self.head_size)
|
||||
kv[:, 0] = k_emb.reshape(-1, self.num_key_value_heads, self.head_size)
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
@ -330,10 +325,13 @@ class Qwen2Model(torch.nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# TODO: ensure we are getting the correct positional embeddings
|
||||
# flatten position ids from 2D to 1D
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids[0, 0, :], true_max_s, hidden_states.dtype
|
||||
position_ids.flatten(), true_max_s, hidden_states.dtype
|
||||
)
|
||||
# reshape cos and sin for the number of position ids present in the input
|
||||
cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
|
||||
sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
|
Loading…
Reference in New Issue
Block a user