fix: adjust positional embeddings for multi dimensional position ids

This commit is contained in:
David Holtz 2024-10-28 14:06:18 +00:00 committed by drbh
parent e1114c2726
commit 279b114ab3

View File

@ -130,23 +130,18 @@ class Qwen2Attention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
_query = query.clone()
_cos = cos.clone()
_sin = sin.clone()
if self.mrope_section is not None:
# if mrope_section is set, we need to split the cos and sin into 3 parts and concatenate them in a specific order
cos = torch.cat(
[m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
dim=-1,
)
sin = torch.cat(
[m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
dim=-1,
)
self.rotary_emb(_query, torch.select(kv, dim=1, index=0), cos, sin)
_cos = torch.cat((_cos, _cos), dim=-1)
_sin = torch.cat((_sin, _sin), dim=-1)
q_emb = (_query * _cos).reshape(2, 1, -1) + (
rotate_half(_query) * _sin
).reshape(2, 1, -1)
k_emb = (torch.select(kv, dim=1, index=0) * _cos).reshape(2, 1, -1) + (
rotate_half(torch.select(kv, dim=1, index=0)) * _sin
).reshape(2, 1, -1)
query = q_emb.reshape(-1, self.num_heads, self.head_size)
kv[:, 0] = k_emb.reshape(-1, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
@ -330,10 +325,13 @@ class Qwen2Model(torch.nn.Module):
) -> torch.Tensor:
hidden_states = inputs_embeds
# TODO: ensure we are getting the correct positional embeddings
# flatten position ids from 2D to 1D
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids[0, 0, :], true_max_s, hidden_states.dtype
position_ids.flatten(), true_max_s, hidden_states.dtype
)
# reshape cos and sin for the number of position ids present in the input
cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
residual = None
for i, layer in enumerate(self.layers):