Debug accuracy issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
Author: yuanwu
Date:   2025-05-05 14:38:34 +00:00
parent dafc597a8b
commit 2dadceaf07


@@ -239,7 +239,6 @@ class Llama4TextMoe(nn.Module):
def forward(self, hidden_states, adapter_data):
#seq_len, hidden_dim = hidden_states.shape
print(f"hidden_states.shape: {hidden_states.shape}")
hidden_states = hidden_states.view(-1, self.hidden_dim)
tokens_per_expert = hidden_states.shape[0]
router_logits = self.router(hidden_states)
@@ -255,7 +254,7 @@ class Llama4TextMoe(nn.Module):
)
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
router_indices = router_indices.reshape(-1, 1).expand(-1, hidden_dim)
router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
routed_in = torch.gather(
input=hidden_states,
dim=0,
@@ -268,7 +267,7 @@ class Llama4TextMoe(nn.Module):
# now that we finished expert computation -> we scatter add because we gathered previously
# we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
# this scales a lot better if you do EP!
out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, hidden_dim))
out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim))
return out
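
A minimal, self-contained sketch of the gather/scatter_add routing that the comment above describes, and of why the index must be expanded to the hidden width (self.hidden_dim in the change above). All sizes and the doubling "expert" below are illustrative stand-ins, not the model's real configuration:

import torch

# Toy sizes; the real values come from the model config.
tokens, hidden_dim, num_experts = 3, 4, 2
hidden_states = torch.randn(tokens, hidden_dim)

# Every expert sees every token: repeat the token indices once per expert,
# then expand them across the hidden dimension so gather/scatter can use them.
token_idx = torch.arange(tokens).repeat(num_experts)        # [num_experts * tokens]
index = token_idx.reshape(-1, 1).expand(-1, hidden_dim)     # [num_experts * tokens, hidden_dim]

routed_in = torch.gather(hidden_states, dim=0, index=index) # replicate token rows per expert
routed_out = routed_in * 2.0                                # stand-in for the per-expert MLPs

out = torch.zeros_like(hidden_states)
# sum every expert's output back into its token row
out.scatter_add_(dim=0, index=index, src=routed_out.view(-1, hidden_dim))
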
@@ -455,51 +454,80 @@ class Llama4TextAttention(FlashLlamaAttention):
slots,
seqlen,
adapter_data,
run_index,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
#hidden_shape = (*input_shape, -1, self.head_dim)
qkv = self.query_key_value(hidden_states, adapter_data)
query_states, kv_states = qkv.split(
# query_states, kv_states = qkv.split(
# [
# self.head_size * self.num_heads,
# 2 * self.head_size * self.num_key_value_heads,
# ],
# dim=-1,
# )
query_states, key_states, value_states = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
self.head_size * self.num_key_value_heads,
self.head_size * self.num_key_value_heads,
],
dim=-1,
)
query_states = query_states.view(-1, self.num_heads, self.head_size)
kv_states = kv_states.view(-1, 2, self.num_key_value_heads, self.head_size)
key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
if run_index != -1:
torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt")
torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt")
torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt")
# query_states = self.q_proj(hidden_states).view(hidden_shape)
# key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
# value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
if self.use_rope: # the 16E model skips rope for long context on certain layers
self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin)
#self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin)
self.rotary_emb(query_states, key_states, cos, sin)
if run_index != -1:
torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt")
torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.key_states.pt")
if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm
query_states = self.qk_norm(query_states)
key_states = self.qk_norm(torch.select(kv_states, dim=1, index=0))
key_states = self.qk_norm(key_states)
if run_index != -1:
torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.query_states.pt")
torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt")
# query_states = query_states.transpose(1, 2)
# key_states = key_states.transpose(1, 2)
kv_cache.store(
key=kv_states[:, 0],
value=kv_states[:, 1],
key=key_states,
value=value_states,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
log_master(
logger.debug,
f"Prefill: {cu_seqlen_prefill} {seqlen} {slots} {self.kv_head_mapping}"
)
# sdpa
attn_output = attention(
query=query_states,
key=kv_states[:, 0],
value=kv_states[:, 1],
key=key_states,
value=value_states,
kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
@@ -507,6 +535,10 @@ class Llama4TextAttention(FlashLlamaAttention):
)
# Decode
else:
log_master(
logger.debug,
f"Decode: {cu_seqlen_prefill} {seqlen} {slots} {self.kv_head_mapping}"
)
attn_output = paged_attention(
query_states,
kv_cache,
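
The attention change above replaces the query plus packed-[K|V] split with three separate tensors. A minimal sketch of that split with toy dimensions (all sizes here are placeholders for the config values):

import torch

# Toy dimensions standing in for num_heads, num_key_value_heads and head_size.
num_heads, num_key_value_heads, head_size, tokens = 8, 2, 16, 4
qkv = torch.randn(tokens, (num_heads + 2 * num_key_value_heads) * head_size)

# New path: split the fused projection into separate Q, K and V tensors
# instead of a query tensor plus a packed [K|V] tensor.
query_states, key_states, value_states = qkv.split(
    [
        head_size * num_heads,
        head_size * num_key_value_heads,
        head_size * num_key_value_heads,
    ],
    dim=-1,
)
query_states = query_states.view(-1, num_heads, head_size)
key_states = key_states.view(-1, num_key_value_heads, head_size)
value_states = value_states.view(-1, num_key_value_heads, head_size)
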
@@ -542,18 +574,18 @@ class Llama4TextDecoderLayer(nn.Module):
else:
self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights, layer_idx)
#self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
#self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
# self.input_layernorm = FastRMSNorm.load(
# prefix=f"{prefix}.input_layernorm",
# weights=weights,
# eps=config.rms_norm_eps,
# )
# self.post_attention_layernorm = FastRMSNorm.load(
# prefix=f"{prefix}.post_attention_layernorm",
# weights=weights,
# eps=config.rms_norm_eps,
# )
self.layer_idx = layer_idx
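
The layer-norm swap above also explains the call-site change below: the tuple unpacking (hidden_states, _ = ...) disappears because a transformers-style RMSNorm returns a single tensor, while FastRMSNorm apparently returns an (output, residual) pair. A minimal RMSNorm sketch using that single-tensor convention (class name and eps are illustrative):

import torch
from torch import nn

class SimpleRMSNorm(nn.Module):
    # Returns a plain tensor, matching the Llama4TextRMSNorm call sites below.
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        variance = x.float().pow(2).mean(-1, keepdim=True)
        x = x.float() * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)
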
@@ -570,9 +602,14 @@ class Llama4TextDecoderLayer(nn.Module):
seqlen,
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
run_index
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states, _ = self.input_layernorm(hidden_states)
if run_index != -1:
torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input.hidden_states.pt")
hidden_states = self.input_layernorm(hidden_states)
if run_index != -1:
torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt")
attention_states = self.self_attn(
hidden_states,
@@ -583,16 +620,27 @@ class Llama4TextDecoderLayer(nn.Module):
slots,
seqlen,
adapter_data,
run_index,
hpu_attention_meta=hpu_attention_meta,
)
if run_index != -1:
torch_save(attention_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.attention_states.pt")
hidden_states = residual + attention_states
if run_index != -1:
torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.hidden_states.pt")
# Fully Connected
residual = hidden_states
hidden_states, _ = self.post_attention_layernorm(hidden_states)
hidden_states = self.post_attention_layernorm(hidden_states)
if run_index != -1:
torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt")
hidden_states = self.feed_forward(hidden_states, adapter_data)
if run_index != -1:
torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt")
hidden_states = residual + hidden_states.view(residual.shape)
if run_index != -1:
torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.output.hidden_states.pt")
#outputs = (hidden_states,)
return hidden_states
# if residual is None:
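
The dumps above are gated on run_index != -1 and call a torch_save helper that this diff does not show. A hypothetical stand-in (the dump directory, environment variable, and CPU copy are assumptions, not the commit's actual helper):

import os
import torch

DUMP_DIR = os.environ.get("DEBUG_DUMP_DIR", "/tmp/llama4_debug")  # assumed location

def torch_save(tensor: torch.Tensor, name: str) -> None:
    # Detach and copy to CPU so tensors living on HPU/GPU can be serialized too.
    os.makedirs(DUMP_DIR, exist_ok=True)
    torch.save(tensor.detach().to("cpu"), os.path.join(DUMP_DIR, name))
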
@@ -648,7 +696,7 @@ class Llama4TextModel(nn.Module):
weights=weights,
eps=config.rms_norm_eps,
)
self.run_index = 0
self.run_index = -1
#self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
self.gradient_checkpointing = False
@@ -666,7 +714,8 @@ class Llama4TextModel(nn.Module):
) -> torch.Tensor:
hidden_states = inputs_embeds
torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
if self.run_index != -1:
torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}")
# Get rotary cos and sin for this forward
# Avoid to index in each layer
@@ -685,12 +734,15 @@ class Llama4TextModel(nn.Module):
seqlen,
adapter_data,
hpu_attention_meta=hpu_attention_meta,
run_index=self.run_index,
)
torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
if self.run_index != -1:
torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}")
hidden_states, _ = self.norm(hidden_states)
torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
if self.run_index != -1:
torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}")
self.run_index += 1
return hidden_states
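
With run_index defaulting to -1 the dumps stay disabled; when enabled, the per-step files written here can be compared offline against a reference run. A sketch of such a comparison, assuming one run wrote tgi.* files and another wrote trans.* files for the same step and layer names (the prefixes mirror the ones used above):

import torch

def compare(name: str, step: int = 0) -> None:
    # Filenames follow the pattern used in the dumps above.
    a = torch.load(f"trans.{step}.{name}.pt").float()
    b = torch.load(f"tgi.{step}.{name}.pt").float()
    diff = (a - b).abs()
    print(f"{name}: max_abs_diff={diff.max().item():.3e} "
          f"allclose={torch.allclose(a, b, atol=1e-3, rtol=1e-3)}")

compare("Llama4TextModel.input.hidden_states")
compare("Llama4TextModel.layers.hidden_states")
compare("Llama4TextModel.norm.hidden_states")
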
@@ -733,7 +785,7 @@ class Llama4ForCausalLM(nn.Module):
adapter_data=adapter_data,
hpu_attention_meta=hpu_attention_meta,
)
print(f"lm_head_indices={lm_head_indices}")
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
@@ -1407,7 +1459,6 @@ class Llama4ForConditionalGeneration(nn.Module):
f"input_ids: {input_ids}, shape = {input_ids.shape}, input_ids={input_ids[-20:]}"
)
inputs_embeds = self.text_model.model.embed_tokens(input_ids)
print(f"LLama4 inputs_embeds shape: {inputs_embeds.shape}")
vision_feature_layer = (
vision_feature_layer
if vision_feature_layer is not None