fix: avoid qwen2 vl specific paths with qwen2

This commit is contained in:
David Holtz 2024-10-29 17:49:50 +00:00
parent 77eb07f73b
commit 620769e380

View File

@ -61,7 +61,11 @@ class Qwen2Attention(torch.nn.Module):
config.sliding_window if config.sliding_window is not None else -1
)
self.num_heads = config.num_attention_heads
self.mrope_section = config.rope_scaling.get("mrope_section", None)
self.mrope_section = (
config.rope_scaling.get("mrope_section", None)
if config.rope_scaling is not None
else None
)
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
@ -322,9 +326,10 @@ class Qwen2Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids.flatten(), true_max_s, hidden_states.dtype
)
# reshape cos and sin for the number of position ids present in the input
cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
# reshape back to 2D if the position_ids were 2D
if position_ids.size(0) != cos.size(0):
cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
residual = None
for i, layer in enumerate(self.layers):
@ -365,7 +370,8 @@ class Qwen2ForCausalLM(torch.nn.Module):
)
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
weights=weights,
)
self.max_past = config.sliding_window