hotfix: qwen2 formatting

Daniël de Kok 2025-03-10 15:01:23 +00:00
parent c5ecc7a4de
commit d0cb06af4a

@@ -51,6 +51,7 @@ def load_attention(config, prefix, weights, layer_id):
         process_group=weights.process_group,
     )
 
+
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
@@ -132,7 +133,7 @@ class Qwen2Attention(torch.nn.Module):
         seqlen,
         max_s,
         prefill_cache_indices,
-        adapter_data
+        adapter_data,
     ):
         qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
@@ -244,7 +245,9 @@ class Qwen2MLP(nn.Module):
     def forward(self, hidden_states, adapter_data):
         gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data)
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+        )
 
 
 class Qwen2Layer(nn.Module):
@@ -254,7 +257,9 @@ class Qwen2Layer(nn.Module):
         self.self_attn = Qwen2Attention(
             index=layer_id, prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id)
+        self.mlp = Qwen2MLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
+        )
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
@@ -293,7 +298,7 @@ class Qwen2Layer(nn.Module):
             seqlen,
             max_s,
             prefill_cache_indices,
-            adapter_data
+            adapter_data,
         )
         hidden_states = attn_output + residual
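
The Qwen2MLP.forward change above is pure line wrapping; the gated-MLP computation itself is unchanged. For reference, a minimal standalone sketch of that computation, assuming plain nn.Linear layers in place of the repository's tensor-parallel and LoRA-aware layers, SiLU as the activation, and with adapter_data omitted:

# Sketch only: GatedMLPSketch, hidden_size, and intermediate_size are
# illustrative names, not part of the diffed file.
import torch
from torch import nn


class GatedMLPSketch(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.intermediate_size = intermediate_size
        # Fused gate + up projection, analogous to gate_up_proj in the diff.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate_up_states = self.gate_up_proj(hidden_states)
        # Split the fused projection: [:, 0] is the gate half, [:, 1] the up half.
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])


# Quick shape check: a batch of 4 token vectors with hidden size 32.
mlp = GatedMLPSketch(hidden_size=32, intermediate_size=64)
print(mlp(torch.randn(4, 32)).shape)  # torch.Size([4, 32])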