From 620769e380099f4f3f2fdf8630cfbb817bd6d28f Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Tue, 29 Oct 2024 17:49:50 +0000
Subject: [PATCH] fix: avoid qwen2 vl specific paths with qwen2

---
 .../custom_modeling/flash_qwen2_modeling.py   | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 8c2c31d6..cc4039b1 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -61,7 +61,11 @@ class Qwen2Attention(torch.nn.Module):
             config.sliding_window if config.sliding_window is not None else -1
         )
         self.num_heads = config.num_attention_heads
-        self.mrope_section = config.rope_scaling.get("mrope_section", None)
+        self.mrope_section = (
+            config.rope_scaling.get("mrope_section", None)
+            if config.rope_scaling is not None
+            else None
+        )
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -322,9 +326,10 @@ class Qwen2Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids.flatten(), true_max_s, hidden_states.dtype
         )
-        # reshape cos and sin for the number of position ids present in the input
-        cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
-        sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
+        # reshape back to 2D if the position_ids were 2D
+        if position_ids.size(0) != cos.size(0):
+            cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
+            sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)

         residual = None
         for i, layer in enumerate(self.layers):
@@ -365,7 +370,8 @@ class Qwen2ForCausalLM(torch.nn.Module):
         )

         self.embed_tokens = TensorParallelEmbedding(
-            prefix=f"{prefix}.embed_tokens", weights=weights
+            prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
+            weights=weights,
        )
         self.max_past = config.sliding_window
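
Note (not part of the patch): a minimal sketch of what the first hunk guards against, assuming a plain Qwen2 text config where rope_scaling is None (the values below are illustrative, not taken from the patch):

    # Hypothetical config value for a plain Qwen2 text model.
    rope_scaling = None
    # Qwen2-VL-style configs instead carry a dict with an "mrope_section" entry.

    # Guarded lookup, mirroring the patched code: before the fix,
    # rope_scaling.get(...) raised AttributeError when rope_scaling was None.
    mrope_section = (
        rope_scaling.get("mrope_section", None)
        if rope_scaling is not None
        else None
    )
    assert mrope_section is None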