Fix the accuracy issue
Signed-off-by: yuanwu <yuan.wu@intel.com>
commit 1a5ff1dc5f
parent a3967a57bc
@@ -110,7 +110,7 @@ class Llama4TextExperts(nn.Module):
         self.intermediate_size = config.intermediate_size // weights.process_group.size()
         self.hidden_size = config.hidden_size
         self.expert_dim = self.intermediate_size
-        self.gate_up_proj = nn.Parameter(weights.get_sharded(f"{prefix}.gate_up_proj", dim=1), requires_grad=False)
+        self.gate_up_proj = nn.Parameter(weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), requires_grad=False)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
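Note: the hunk above is the heart of the accuracy fix. `gate_up_proj` packs the gate and up projections along its last dimension, so under tensor parallelism it has to be sharded block-wise, with every rank keeping its slice of both blocks; `get_packed_sharded(dim=-1, block_sizes=2)` does exactly that, whereas a plain contiguous shard cuts the packed dimension in one piece and splits gate from up across ranks. A minimal standalone sketch of the difference (tensor shapes and names here are illustrative assumptions, not code from this commit):

```python
import torch

# Packed [gate | up] weight, assumed layout [num_experts, hidden, 2 * intermediate].
num_experts, hidden, intermediate, world_size, rank = 4, 8, 6, 2, 0
w = torch.randn(num_experts, hidden, 2 * intermediate)

# Plain contiguous shard of the packed dim: rank 0 gets only gate columns,
# rank 1 only up columns -- the packed layout is broken.
half = w.shape[-1] // world_size
contiguous = w[..., rank * half : (rank + 1) * half]

# Block-wise ("packed") shard: each rank takes its slice of the gate block
# and the matching slice of the up block.
shard = intermediate // world_size
gate_cols = list(range(rank * shard, (rank + 1) * shard))
up_cols = [intermediate + c for c in gate_cols]
packed = w[..., gate_cols + up_cols]

assert contiguous.shape == packed.shape == (num_experts, hidden, 2 * shard)
```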
@@ -119,7 +119,7 @@ class Llama4TextExperts(nn.Module):
         )


-        self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=0), requires_grad=False)
+        self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
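Note: sharding `down_proj` on dim=1 pairs with the packed gate/up shard above: each rank's experts produce activations over their local slice of the intermediate dimension, so `down_proj` has to be cut along that same dimension (a cross-rank all-reduce then recombines the partial sums); cutting dim=0 would split the expert axis instead. A rough shape check, assuming the usual `[num_experts, expert_dim, hidden]` layout and a SiLU gate (both are assumptions, not taken from this diff):

```python
import torch
import torch.nn.functional as F

num_experts, hidden, expert_dim, world_size = 4, 8, 6, 2
local_dim = expert_dim // world_size

hidden_states = torch.randn(num_experts, 5, hidden)              # 5 tokens per expert
gate_up_local = torch.randn(num_experts, hidden, 2 * local_dim)  # per-rank packed shard
down_local = torch.randn(num_experts, expert_dim, hidden)[:, :local_dim, :]  # dim=1 shard

gate, up = torch.bmm(hidden_states, gate_up_local).chunk(2, dim=-1)
out_local = torch.bmm(F.silu(gate) * up, down_local)

# Each rank holds a partial sum over its slice of expert_dim; an all-reduce
# across ranks would complete the projection.
assert out_local.shape == (num_experts, 5, hidden)
```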
@@ -144,23 +144,23 @@ class Llama4TextExperts(nn.Module):
             torch.Tensor
         """
         gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2*self.expert_dim)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(gate_up_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate_up_proj.pt")


         down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(down_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.down_proj.pt")


         hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.hidden_states.pt")


         gate_up = torch.bmm(hidden_states, gate_up_proj)
         gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
-        if run_index != -1:
+        if run_index == 0:
             torch_save(gate, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate.pt")
             torch_save(up, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.up.pt")

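Note: the `run_index != -1` → `run_index == 0` changes above (and the commented-out calls in later hunks) restrict the debugging dumps to the first forward pass instead of every step. The `torch_save` helper itself is not part of this diff; a hypothetical sketch of what such a dump helper could look like (the environment variable and dump directory are assumptions):

```python
import os
import torch

def torch_save(tensor: torch.Tensor, filename: str, dump_dir: str = "./debug_dumps") -> None:
    """Dump a tensor for offline comparison against a reference run."""
    if os.environ.get("LLAMA4_DEBUG_DUMP", "0") != "1":
        return
    os.makedirs(dump_dir, exist_ok=True)
    # Detach and move to CPU so the dump holds neither device memory nor the autograd graph.
    torch.save(tensor.detach().to("cpu"), os.path.join(dump_dir, filename))
```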
@@ -169,7 +169,7 @@ class Llama4TextExperts(nn.Module):


         next_states = next_states.view(-1, self.hidden_size)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(next_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.next_states.pt")

         # Reduce sum
@@ -303,7 +303,7 @@ class Llama4TextMoe(nn.Module):
         # assert isinstance(self.experts, MoELayer)


-        self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights)
+        self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights, layer_idx=layer_idx)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
@@ -319,7 +319,7 @@ class Llama4TextMoe(nn.Module):
             logger.debug,
             f"TextMode2 Free memory real: {real_free_memory / 1e9:.2f}GB"
         )
-        self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights)
+        self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights, layer_idx=layer_idx)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
@@ -334,8 +334,8 @@ class Llama4TextMoe(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         tokens_per_expert = hidden_states.shape[0]
         router_logits = self.router(hidden_states)
-        if run_index != -1:
-            torch_save(router_logits, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_logits.pt")
+        #if run_index != -1:
+        # torch_save(router_logits, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_logits.pt")

         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
@@ -347,8 +347,8 @@ class Llama4TextMoe(nn.Module):
             torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
         )
         router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
-        if run_index != -1:
-            torch_save(router_scores, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.router_scores.pt")
+        #if run_index != -1:
+        # torch_save(router_scores, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.router_scores.pt")


         router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
@@ -357,20 +357,20 @@ class Llama4TextMoe(nn.Module):
             dim=0,
             index=router_indices,
         ).to(hidden_states.device)
-        if run_index != -1:
-            torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.gather.pt")
+        #if run_index != -1:
+        # torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.gather.pt")


         # we gather inputs corresponding to each expert based on the router indices
         routed_in = routed_in * router_scores.reshape(-1, 1)
-        if run_index != -1:
-            torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_in.pt")
+        #if run_index != -1:
+        # torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_in.pt")
         routed_out = self.experts(routed_in, run_index)
-        if run_index != -1:
-            torch_save(routed_out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_out.pt")
+        #if run_index != -1:
+        # torch_save(routed_out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_out.pt")
         out = self.shared_expert(hidden_states, run_index, reduce=False)
-        if run_index != -1:
-            torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.out.pt")
+        #if run_index != -1:
+        # torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.out.pt")
         # now that we finished expert computation -> we scatter add because we gathered previously
         # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
         # this scales a lot better if you do EP!
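Note: for orientation, the routing around these hunks follows a gather → experts → scatter-add pattern: every expert receives every token, tokens an expert was not routed to get a zero weight (sigmoid of -inf), and the expert outputs are scatter-added back onto their token positions before the shared-expert output is added in. A minimal self-contained sketch of that pattern (the shapes and the identity stand-in for the expert MLPs are assumptions, not this commit's code):

```python
import torch

tokens, hidden_dim, num_experts, top_k = 5, 8, 4, 2
hidden_states = torch.randn(tokens, hidden_dim)
router_logits = torch.randn(tokens, num_experts)

router_top_value, router_indices = torch.topk(router_logits, top_k, dim=1)
# [num_experts, tokens]: -inf for non-top-k pairs, so sigmoid gives them weight 0.
router_scores = (
    torch.full_like(router_logits, float("-inf"))
    .scatter_(1, router_indices, router_top_value)
    .transpose(0, 1)
)
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

# Gather a copy of every token for every expert, then weight it by its score.
token_index = (
    torch.arange(tokens).view(1, -1).expand(num_experts, -1).reshape(-1, 1).expand(-1, hidden_dim)
)
routed_in = torch.gather(hidden_states, dim=0, index=token_index) * router_scores.reshape(-1, 1)

routed_out = routed_in  # stand-in for the expert MLPs

# Scatter-add the (num_experts * tokens) rows back onto their token positions.
out = torch.zeros_like(hidden_states)
out.scatter_add_(0, token_index, routed_out)
assert out.shape == (tokens, hidden_dim)
```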
@@ -381,8 +381,8 @@ class Llama4TextMoe(nn.Module):
         # if self.process_group.size() > 1:
         # torch.distributed.all_reduce(out, group=self.process_group)

-        if run_index != -1:
-            torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.add.out.pt")
+        #if run_index != -1:
+        # torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.add.out.pt")

         return out

@@ -689,10 +689,10 @@ class Llama4TextAttention(FlashLlamaAttention):
         # key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
         # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)

-        if run_index != -1:
-            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt")
-            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt")
-            torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt")
+        #if run_index != -1:
+        # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt")
+        # torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt")
+        # torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt")

         if self.use_rope: # the 16E model skips rope for long context on certain layers
             #self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin)
@@ -703,18 +703,18 @@ class Llama4TextAttention(FlashLlamaAttention):



-        if run_index != -1:
-            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt")
-            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.key_states.pt")
+        #if run_index != -1:
+        # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt")
+        # torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.key_states.pt")


         if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm
             query_states = self.qk_norm(query_states)
             key_states = self.qk_norm(key_states)

-        if run_index != -1:
-            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.query_states.pt")
-            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt")
+        #if run_index != -1:
+        # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.query_states.pt")
+        # torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt")


         # query_states = query_states.view(-1, self.num_heads, self.head_size)
@@ -742,9 +742,9 @@ class Llama4TextAttention(FlashLlamaAttention):
         #seq_len = input_shape / bs
         attn_scales = attn_scales.view(*input_shape, 1, 1)
         query_states = (query_states * attn_scales).to(query_states.dtype)
-        if run_index != -1:
-            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attn_scales.query_states.pt")
-            torch_save(attention_mask, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_mask.pt")
+        #if run_index != -1:
+        # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attn_scales.query_states.pt")
+        # torch_save(attention_mask, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_mask.pt")


         # Prefill
@@ -806,8 +806,8 @@ class Llama4TextAttention(FlashLlamaAttention):
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        if run_index != -1:
-            torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.reshape.attn_output.pt")
+        #if run_index != -1:
+        # torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.reshape.attn_output.pt")
         attn_output = self.o_proj(attn_output)
         return attn_output

@@ -873,11 +873,11 @@ class Llama4TextDecoderLayer(nn.Module):
         run_index: int = 0,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input.hidden_states.pt")
         hidden_states = self.input_layernorm(hidden_states)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt")

         # use local attention mask for ROPE layers
         if self.use_chunked_attention and chunk_causal_mask is not None:
@@ -896,24 +896,24 @@ class Llama4TextDecoderLayer(nn.Module):
             position_ids=position_ids,
             hpu_attention_meta=hpu_attention_meta,
         )
-        if run_index != -1:
-            torch_save(attention_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.attention_states.pt")
+        #if run_index != -1:
+        # torch_save(attention_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.attention_states.pt")
         hidden_states = residual + attention_states
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.hidden_states.pt")

         # Fully Connected
         residual = hidden_states

         hidden_states = self.post_attention_layernorm(hidden_states)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt")
         hidden_states = self.feed_forward(hidden_states, adapter_data, run_index)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt")
         hidden_states = residual + hidden_states.view(residual.shape)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.output.hidden_states.pt")
+        #if run_index != -1:
+        # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.output.hidden_states.pt")
         #outputs = (hidden_states,)
         return hidden_states
         # if residual is None:
@@ -988,8 +988,8 @@ class Llama4TextModel(nn.Module):
     ) -> torch.Tensor:

         hidden_states = inputs_embeds
-        if self.run_index != -1:
-            torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
+        #if self.run_index != -1:
+        # torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
         log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}")
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -1030,11 +1030,11 @@ class Llama4TextModel(nn.Module):
                 run_index=self.run_index,
             )

-        if self.run_index != -1:
+        if self.run_index == 0:
             torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
         log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}")
         hidden_states, _ = self.norm(hidden_states)
-        if self.run_index != -1:
+        if self.run_index == 0:
             torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
         log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}")
         self.run_index += 1
@@ -2014,4 +2014,4 @@ class Llama4ForConditionalGeneration(nn.Module):
             attention_mask
         )

         return logits, speculative_logits
@@ -303,7 +303,7 @@ class Weights:
         world_size = self.process_group.size()
         rank = self.process_group.rank()

-        tensors = []
+        tensors_slices = []
         block_offset = 0
         for block_size in block_sizes:
             assert (
@@ -312,15 +312,19 @@ class Weights:
             shard_block_size = block_size // world_size
             start = rank * shard_block_size
             stop = (rank + 1) * shard_block_size
-            if dim == 0:
-                tensor = slice_[block_offset + start : block_offset + stop]
-            elif dim == 1:
-                tensor = slice_[:, block_offset + start : block_offset + stop]
-            else:
-                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
-            tensors.append(tensor)
+            tensors_slices += range(block_offset + start, block_offset + stop)
             block_offset += block_size
-        tensor = torch.cat(tensors, dim=dim)

+        if dim == 0:
+            tensor = slice_[tensors_slices, ...]
+        elif dim == 1 or dim == -2:
+            tensor = slice_[:, tensors_slices, ...]
+        elif dim == 2 or dim == -1:
+            tensor = slice_[..., tensors_slices]
+        else:
+            raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")

         tensor = tensor.to(device=self.device)

         # Avoid casting quantizer dtypes.
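Note: the new `get_packed_sharded` body above collects, per packed block, the index range owned by the current rank and then slices the weight once along the requested dim (0, 1/-2, or 2/-1); the `dim=-1` used for `gate_up_proj` lands in the last branch. A standalone sketch of that index construction, outside the `Weights` class (the helper name and shapes are assumptions):

```python
import torch

def packed_shard_indices(block_sizes, rank, world_size):
    """Per block, keep the [rank*shard, (rank+1)*shard) slice of that block."""
    slices, offset = [], 0
    for block_size in block_sizes:
        assert block_size % world_size == 0, "block must divide evenly across ranks"
        shard = block_size // world_size
        slices += range(offset + rank * shard, offset + (rank + 1) * shard)
        offset += block_size
    return slices

# A packed [gate | up] tensor with two 6-column blocks, sharded on the last dim.
w = torch.randn(4, 8, 12)
cols = packed_shard_indices([6, 6], rank=0, world_size=2)  # -> [0, 1, 2, 6, 7, 8]
assert w[..., cols].shape == (4, 8, 6)
```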