diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 5fe39bc9..f411c849 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -130,23 +130,18 @@ class Qwen2Attention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - _query = query.clone() - _cos = cos.clone() - _sin = sin.clone() + if self.mrope_section is not None: + # if mrope_section is set, we need to split the cos and sin into 3 parts and concatenate them in a specific order + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], + dim=-1, + ) - self.rotary_emb(_query, torch.select(kv, dim=1, index=0), cos, sin) - - _cos = torch.cat((_cos, _cos), dim=-1) - _sin = torch.cat((_sin, _sin), dim=-1) - q_emb = (_query * _cos).reshape(2, 1, -1) + ( - rotate_half(_query) * _sin - ).reshape(2, 1, -1) - k_emb = (torch.select(kv, dim=1, index=0) * _cos).reshape(2, 1, -1) + ( - rotate_half(torch.select(kv, dim=1, index=0)) * _sin - ).reshape(2, 1, -1) - - query = q_emb.reshape(-1, self.num_heads, self.head_size) - kv[:, 0] = k_emb.reshape(-1, self.num_key_value_heads, self.head_size) + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) if prefill_cache_indices is not None: kv_to_cache = kv[prefill_cache_indices] @@ -330,10 +325,13 @@ class Qwen2Model(torch.nn.Module): ) -> torch.Tensor: hidden_states = inputs_embeds - # TODO: ensure we are getting the correct positional embeddings + # flatten position ids from 2D to 1D cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids[0, 0, :], true_max_s, hidden_states.dtype + position_ids.flatten(), true_max_s, hidden_states.dtype ) + # reshape cos and sin for the number of position ids present in the input + cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) + sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) residual = None for i, layer in enumerate(self.layers):