Fix the accuracy issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
Author: yuanwu
Date: 2025-05-08 08:29:59 +00:00
parent a3967a57bc
commit 1a5ff1dc5f
2 changed files with 70 additions and 66 deletions

View File

@@ -110,7 +110,7 @@ class Llama4TextExperts(nn.Module):
         self.intermediate_size = config.intermediate_size // weights.process_group.size()
         self.hidden_size = config.hidden_size
         self.expert_dim = self.intermediate_size
-        self.gate_up_proj = nn.Parameter(weights.get_sharded(f"{prefix}.gate_up_proj", dim=1), requires_grad=False)
+        self.gate_up_proj = nn.Parameter(weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), requires_grad=False)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
@@ -119,7 +119,7 @@ class Llama4TextExperts(nn.Module):
         )
-        self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=0), requires_grad=False)
+        self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
@@ -144,23 +144,23 @@ class Llama4TextExperts(nn.Module):
             torch.Tensor
         """
         gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2*self.expert_dim)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(gate_up_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate_up_proj.pt")
         down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(down_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.down_proj.pt")
         hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.hidden_states.pt")
         gate_up = torch.bmm(hidden_states, gate_up_proj)
         gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
-        if run_index != -1:
+        if run_index == 0:
             torch_save(gate, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate.pt")
             torch_save(up, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.up.pt")
@@ -169,7 +169,7 @@ class Llama4TextExperts(nn.Module):
         next_states = next_states.view(-1, self.hidden_size)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(next_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.next_states.pt")
         # Reduce sum
@@ -303,7 +303,7 @@ class Llama4TextMoe(nn.Module):
         # assert isinstance(self.experts, MoELayer)
-        self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights)
+        self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights, layer_idx=layer_idx)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
@@ -319,7 +319,7 @@ class Llama4TextMoe(nn.Module):
             logger.debug,
             f"TextMode2 Free memory real: {real_free_memory / 1e9:.2f}GB"
         )
-        self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights)
+        self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights, layer_idx=layer_idx)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
@@ -334,8 +334,8 @@ class Llama4TextMoe(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         tokens_per_expert = hidden_states.shape[0]
         router_logits = self.router(hidden_states)
-        if run_index != -1:
-            torch_save(router_logits, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_logits.pt")
+        #if run_index != -1:
+        # torch_save(router_logits, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_logits.pt")
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
@@ -347,8 +347,8 @@ class Llama4TextMoe(nn.Module):
             torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
         )
         router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
-        if run_index != -1:
-            torch_save(router_scores, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.router_scores.pt")
+        #if run_index != -1:
+        # torch_save(router_scores, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.router_scores.pt")
         router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
@@ -357,20 +357,20 @@ class Llama4TextMoe(nn.Module):
             dim=0,
             index=router_indices,
         ).to(hidden_states.device)
-        if run_index != -1:
-            torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.gather.pt")
+        #if run_index != -1:
+        # torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.gather.pt")
         # we gather inputs corresponding to each expert based on the router indices
         routed_in = routed_in * router_scores.reshape(-1, 1)
-        if run_index != -1:
-            torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_in.pt")
+        #if run_index != -1:
+        # torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_in.pt")
         routed_out = self.experts(routed_in, run_index)
-        if run_index != -1:
-            torch_save(routed_out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_out.pt")
+        #if run_index != -1:
+        # torch_save(routed_out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_out.pt")
         out = self.shared_expert(hidden_states, run_index, reduce=False)
-        if run_index != -1:
-            torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.out.pt")
+        #if run_index != -1:
+        # torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.out.pt")
         # now that we finished expert computation -> we scatter add because we gathered previously
         # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
         # this scales a lot better if you do EP!
@@ -381,8 +381,8 @@ class Llama4TextMoe(nn.Module):
         # if self.process_group.size() > 1:
         # torch.distributed.all_reduce(out, group=self.process_group)
-        if run_index != -1:
-            torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.add.out.pt")
+        #if run_index != -1:
+        # torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.add.out.pt")
         return out
@@ -689,10 +689,10 @@ class Llama4TextAttention(FlashLlamaAttention):
         # key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
         # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
-        if run_index != -1:
-            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt")
-            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt")
-            torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt")
+        #if run_index != -1:
+        # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt")
+        # torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt")
+        # torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt")
         if self.use_rope: # the 16E model skips rope for long context on certain layers
             #self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin)
@@ -703,18 +703,18 @@ class Llama4TextAttention(FlashLlamaAttention):
-            if run_index != -1:
-                torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt")
-                torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.key_states.pt")
+            #if run_index != -1:
+            # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt")
+            # torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.key_states.pt")
         if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm
             query_states = self.qk_norm(query_states)
             key_states = self.qk_norm(key_states)
-            if run_index != -1:
-                torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.query_states.pt")
-                torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt")
+            #if run_index != -1:
+            # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.query_states.pt")
+            # torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt")
         # query_states = query_states.view(-1, self.num_heads, self.head_size)
@@ -742,9 +742,9 @@ class Llama4TextAttention(FlashLlamaAttention):
             #seq_len = input_shape / bs
             attn_scales = attn_scales.view(*input_shape, 1, 1)
             query_states = (query_states * attn_scales).to(query_states.dtype)
-            if run_index != -1:
-                torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attn_scales.query_states.pt")
-                torch_save(attention_mask, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_mask.pt")
+            #if run_index != -1:
+            # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attn_scales.query_states.pt")
+            # torch_save(attention_mask, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_mask.pt")
         # Prefill
@@ -806,8 +806,8 @@ class Llama4TextAttention(FlashLlamaAttention):
         )
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        if run_index != -1:
-            torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.reshape.attn_output.pt")
+        #if run_index != -1:
+        # torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.reshape.attn_output.pt")
         attn_output = self.o_proj(attn_output)
         return attn_output
@@ -873,11 +873,11 @@ class Llama4TextDecoderLayer(nn.Module):
         run_index: int = 0,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input.hidden_states.pt")
         hidden_states = self.input_layernorm(hidden_states)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt")
         # use local attention mask for ROPE layers
         if self.use_chunked_attention and chunk_causal_mask is not None:
@@ -896,24 +896,24 @@ class Llama4TextDecoderLayer(nn.Module):
             position_ids=position_ids,
             hpu_attention_meta=hpu_attention_meta,
         )
-        if run_index != -1:
-            torch_save(attention_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.attention_states.pt")
+        #if run_index != -1:
+        # torch_save(attention_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.attention_states.pt")
         hidden_states = residual + attention_states
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.hidden_states.pt")
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt")
         hidden_states = self.feed_forward(hidden_states, adapter_data, run_index)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt")
         hidden_states = residual + hidden_states.view(residual.shape)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.output.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.output.hidden_states.pt")
         #outputs = (hidden_states,)
         return hidden_states
         # if residual is None:
@@ -988,8 +988,8 @@ class Llama4TextModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
-        if self.run_index != -1:
-            torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
+        #if self.run_index != -1:
+        # torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
         log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}")
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -1030,11 +1030,11 @@ class Llama4TextModel(nn.Module):
                 run_index=self.run_index,
             )
-        if self.run_index != -1:
+        if self.run_index == 0:
             torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
         log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}")
         hidden_states, _ = self.norm(hidden_states)
-        if self.run_index != -1:
+        if self.run_index == 0:
             torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
         log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}")
         self.run_index += 1
@@ -2014,4 +2014,4 @@ class Llama4ForConditionalGeneration(nn.Module):
             attention_mask
         )
         return logits, speculative_logits
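
The core of the accuracy fix in this file is the weight-loading change at the top: `gate_up_proj` packs the gate and up projections along its last dimension, so it is now loaded with `weights.get_packed_sharded(..., dim=-1, block_sizes=2)` instead of a plain `get_sharded` slice, and `down_proj` is sharded along dim 1, the tensor-parallel intermediate dimension. Without block-wise sharding, the later `gate, up = gate_up.chunk(2, dim=-1)` would split mismatched halves on each rank. The standalone toy sketch below is not part of the diff; shapes, sizes, and helper names are made up for illustration.

# Toy sketch (not TGI code): why a packed [gate | up] weight must be sharded
# block-wise across tensor-parallel ranks. All names and sizes are illustrative.
import torch

num_experts, hidden, inter = 2, 4, 6      # toy sizes
world_size = 2                            # pretend tensor-parallel degree
# Packed layout: last dim is [gate (inter) | up (inter)] = 2 * inter columns.
gate_up = torch.arange(num_experts * hidden * 2 * inter, dtype=torch.float32)
gate_up = gate_up.view(num_experts, hidden, 2 * inter)

def naive_shard(w, rank):
    # Contiguous sharding of the last dim: rank 0 receives only gate columns
    # and rank 1 only up columns, so the halves are no longer paired per rank.
    shard = w.shape[-1] // world_size
    return w[..., rank * shard:(rank + 1) * shard]

def packed_shard(w, rank, num_blocks=2):
    # Shard each of the equal packed blocks (gate, up) separately and re-pack,
    # mirroring what a packed-aware loader with block_sizes=2 does.
    block = w.shape[-1] // num_blocks
    cols = []
    for b in range(num_blocks):
        shard = block // world_size
        start = b * block + rank * shard
        cols += range(start, start + shard)
    return w[..., cols]

for rank in range(world_size):
    p = packed_shard(gate_up, rank)
    gate, up = p.chunk(2, dim=-1)         # matches the chunk(2, dim=-1) in forward
    print(rank, naive_shard(gate_up, rank).shape, p.shape, gate.shape, up.shape)

With the block-wise shard, each rank's local slice still splits cleanly into matching gate and up halves; the contiguous shard gives rank 0 only gate columns and rank 1 only up columns, which is the kind of mismatch this commit addresses.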

View File

@@ -303,7 +303,7 @@ class Weights:
         world_size = self.process_group.size()
         rank = self.process_group.rank()
-        tensors = []
+        tensors_slices = []
         block_offset = 0
         for block_size in block_sizes:
             assert (
@@ -312,15 +312,19 @@ class Weights:
             shard_block_size = block_size // world_size
             start = rank * shard_block_size
             stop = (rank + 1) * shard_block_size
-            if dim == 0:
-                tensor = slice_[block_offset + start : block_offset + stop]
-            elif dim == 1:
-                tensor = slice_[:, block_offset + start : block_offset + stop]
-            else:
-                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
-            tensors.append(tensor)
+            tensors_slices += range(block_offset + start, block_offset + stop)
             block_offset += block_size
-        tensor = torch.cat(tensors, dim=dim)
+        if dim == 0:
+            tensor = slice_[tensors_slices, ...]
+        elif dim == 1 or dim == -2:
+            tensor = slice_[:, tensors_slices, ...]
+        elif dim == 2 or dim == -1:
+            tensor = slice_[..., tensors_slices]
+        else:
+            raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
         tensor = tensor.to(device=self.device)
         # Avoid casting quantizer dtypes.
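
The rewritten loop above no longer materializes one tensor per block and concatenates them; it accumulates the index ranges this rank keeps from every packed block and performs a single indexed read along the requested dimension, which also adds support for dim 2 / -1. Below is a minimal standalone sketch of that logic, using an ordinary torch.Tensor in place of the lazy safetensors `slice_` (an assumption for the example) and hypothetical helper names.

# Standalone sketch of the sharding logic introduced above; not the library code.
import torch
from typing import List

def packed_shard_indices(block_sizes: List[int], world_size: int, rank: int) -> List[int]:
    # For each packed block, keep this rank's contiguous sub-range, then
    # concatenate the kept indices across blocks.
    indices, block_offset = [], 0
    for block_size in block_sizes:
        assert block_size % world_size == 0, "block not divisible by world size"
        shard = block_size // world_size
        indices += range(block_offset + rank * shard, block_offset + (rank + 1) * shard)
        block_offset += block_size
    return indices

def shard_packed(tensor: torch.Tensor, dim: int, block_sizes: List[int],
                 world_size: int, rank: int) -> torch.Tensor:
    # Single indexed read along the requested dim, as in the new code path.
    idx = packed_shard_indices(block_sizes, world_size, rank)
    if dim == 0:
        return tensor[idx, ...]
    elif dim in (1, -2):
        return tensor[:, idx, ...]
    elif dim in (2, -1):
        return tensor[..., idx]
    raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")

# Example: a (experts=4, hidden=8, 2*inter=12) packed gate_up weight on 2 ranks.
w = torch.randn(4, 8, 12)
print(shard_packed(w, dim=-1, block_sizes=[6, 6], world_size=2, rank=0).shape)  # torch.Size([4, 8, 6])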