Debug accuracy issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-05 14:38:34 +00:00
parent dafc597a8b
commit 2dadceaf07

@@ -239,7 +239,6 @@ class Llama4TextMoe(nn.Module):
     def forward(self, hidden_states, adapter_data):
         #seq_len, hidden_dim = hidden_states.shape
-        print(f"hidden_states.shape: {hidden_states.shape}")
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         tokens_per_expert = hidden_states.shape[0]
         router_logits = self.router(hidden_states)
@@ -255,7 +254,7 @@ class Llama4TextMoe(nn.Module):
         )
         router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
-        router_indices = router_indices.reshape(-1, 1).expand(-1, hidden_dim)
+        router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
         routed_in = torch.gather(
             input=hidden_states,
             dim=0,
@@ -268,7 +267,7 @@ class Llama4TextMoe(nn.Module):
         # now that we finished expert computation -> we scatter add because we gathered previously
         # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
         # this scales a lot better if you do EP!
-        out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, hidden_dim))
+        out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim))
         return out
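
Note: the two hunks above fix a scoping bug: `hidden_dim` only existed in the commented-out unpack line, so the expand/scatter width has to come from `self.hidden_dim`. As a self-contained sketch of the gather/scatter_add round trip this code relies on (toy sizes and a dummy expert, not the TGI module):

import torch

num_tokens, hidden_dim, top_k = 4, 8, 2  # illustrative sizes only
hidden_states = torch.randn(num_tokens, hidden_dim)
# One token index per routed slot, as produced by the router's top-k selection.
router_indices = torch.randint(0, num_tokens, (num_tokens * top_k,))

# Expand indices to the hidden width so whole token rows are gathered and scattered back.
index = router_indices.reshape(-1, 1).expand(-1, hidden_dim)
routed_in = torch.gather(input=hidden_states, dim=0, index=index)
routed_out = routed_in * 2.0  # stand-in for the expert MLPs
out = torch.zeros_like(hidden_states)
out.scatter_add_(dim=0, index=index, src=routed_out.view(-1, hidden_dim))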
@@ -455,51 +454,80 @@ class Llama4TextAttention(FlashLlamaAttention):
         slots,
         seqlen,
         adapter_data,
+        run_index,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         #hidden_shape = (*input_shape, -1, self.head_dim)
         qkv = self.query_key_value(hidden_states, adapter_data)
-        query_states, kv_states = qkv.split(
+        # query_states, kv_states = qkv.split(
+        #     [
+        #         self.head_size * self.num_heads,
+        #         2 * self.head_size * self.num_key_value_heads,
+        #     ],
+        #     dim=-1,
+        # )
+        query_states, key_states, value_states = qkv.split(
             [
                 self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
             ],
             dim=-1,
         )
         query_states = query_states.view(-1, self.num_heads, self.head_size)
-        kv_states = kv_states.view(-1, 2, self.num_key_value_heads, self.head_size)
+        key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
+        value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
+        if run_index != -1:
+            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt")
+            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt")
+            torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt")
         # query_states = self.q_proj(hidden_states).view(hidden_shape)
         # key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
         # value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         if self.use_rope:  # the 16E model skips rope for long context on certain layers
-            self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin)
+            #self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin)
+            self.rotary_emb(query_states, key_states, cos, sin)
+            if run_index != -1:
+                torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt")
+                torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.key_states.pt")
         if hasattr(self, "qk_norm"):  # the 128E model does not use qk_norm
             query_states = self.qk_norm(query_states)
-            key_states = self.qk_norm(torch.select(kv_states, dim=1, index=0))
+            key_states = self.qk_norm(key_states)
+            if run_index != -1:
+                torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.query_states.pt")
+                torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt")
         # query_states = query_states.transpose(1, 2)
         # key_states = key_states.transpose(1, 2)
         kv_cache.store(
-            key=kv_states[:, 0],
-            value=kv_states[:, 1],
+            key=key_states,
+            value=value_states,
             slots=slots,
             kv_scales=self.kv_scales,
         )
         # Prefill
         if cu_seqlen_prefill is not None:
+            log_master(
+                logger.debug,
+                f"Prefill: {cu_seqlen_prefill} {seqlen} {slots} {self.kv_head_mapping}"
+            )
             # sdpa
             attn_output = attention(
                 query=query_states,
-                key=kv_states[:, 0],
-                value=kv_states[:, 1],
+                key=key_states,
+                value=value_states,
                 kv_scales=self.kv_scales,
                 kv_cache=kv_cache,
                 seqlen=seqlen,
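
Note: the hunk above replaces the query + packed-kv split with a three-way split into separate query/key/value tensors and dumps each one via `torch_save` when `run_index` is enabled. A standalone sketch (toy dimensions, illustrative names) showing that the two layouts carry identical data, so the change only affects how the tensors are handled downstream:

import torch

num_tokens, num_heads, num_kv_heads, head_size = 3, 8, 2, 16  # toy sizes
qkv = torch.randn(num_tokens, head_size * (num_heads + 2 * num_kv_heads))

# Old layout: query plus a packed (key, value) tensor.
q_old, kv = qkv.split([head_size * num_heads, 2 * head_size * num_kv_heads], dim=-1)
kv = kv.view(-1, 2, num_kv_heads, head_size)

# New layout: three separate tensors from a single split call.
q_new, k_new, v_new = qkv.split(
    [head_size * num_heads, head_size * num_kv_heads, head_size * num_kv_heads],
    dim=-1,
)
k_new = k_new.view(-1, num_kv_heads, head_size)
v_new = v_new.view(-1, num_kv_heads, head_size)

assert torch.equal(q_old, q_new)
assert torch.equal(kv[:, 0], k_new) and torch.equal(kv[:, 1], v_new)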
@@ -507,6 +535,10 @@ class Llama4TextAttention(FlashLlamaAttention):
             )
         # Decode
         else:
+            log_master(
+                logger.debug,
+                f"Decode: {cu_seqlen_prefill} {seqlen} {slots} {self.kv_head_mapping}"
+            )
             attn_output = paged_attention(
                 query_states,
                 kv_cache,
@@ -542,18 +574,18 @@ class Llama4TextDecoderLayer(nn.Module):
         else:
             self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights, layer_idx)
-        #self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
-        #self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
-        self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
-        self.post_attention_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
+        self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
+        self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
+        # self.input_layernorm = FastRMSNorm.load(
+        #     prefix=f"{prefix}.input_layernorm",
+        #     weights=weights,
+        #     eps=config.rms_norm_eps,
+        # )
+        # self.post_attention_layernorm = FastRMSNorm.load(
+        #     prefix=f"{prefix}.post_attention_layernorm",
+        #     weights=weights,
+        #     eps=config.rms_norm_eps,
+        # )
         self.layer_idx = layer_idx
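
Note: this hunk swaps the decoder layer norms from `FastRMSNorm.load` back to `Llama4TextRMSNorm` while the accuracy gap is being narrowed down (the call sites change accordingly, since `FastRMSNorm` returns a `(hidden, residual)` tuple and `Llama4TextRMSNorm` returns a single tensor). For checking the dumped layernorm outputs, a minimal float32 RMSNorm reference, assuming the standard formulation:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps) computed in float32, then scaled by the learned weight.
    # eps should match config.rms_norm_eps of the model being debugged.
    hidden = x.float()
    hidden = hidden * torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + eps)
    return weight * hidden.to(x.dtype)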
@@ -570,9 +602,14 @@ class Llama4TextDecoderLayer(nn.Module):
         seqlen,
         adapter_data,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+        run_index,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
-        hidden_states, _ = self.input_layernorm(hidden_states)
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input.hidden_states.pt")
+        hidden_states = self.input_layernorm(hidden_states)
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt")
         attention_states = self.self_attn(
             hidden_states,
@@ -583,16 +620,27 @@ class Llama4TextDecoderLayer(nn.Module):
             slots,
             seqlen,
             adapter_data,
+            run_index,
             hpu_attention_meta=hpu_attention_meta,
         )
+        if run_index != -1:
+            torch_save(attention_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.attention_states.pt")
         hidden_states = residual + attention_states
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.hidden_states.pt")
         # Fully Connected
         residual = hidden_states
-        hidden_states, _ = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt")
         hidden_states = self.feed_forward(hidden_states, adapter_data)
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt")
         hidden_states = residual + hidden_states.view(residual.shape)
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.output.hidden_states.pt")
         #outputs = (hidden_states,)
         return hidden_states
         # if residual is None:
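
Note: `torch_save` is the dump helper used throughout these hunks; its definition is not part of this diff. A minimal sketch that would satisfy the call sites above (the dump directory and the detach/CPU handling are assumptions):

import os
import torch

def torch_save(tensor: torch.Tensor, name: str, dump_dir: str = "/tmp/llama4_dumps") -> None:
    # Detach, move to CPU, and save under the dump directory with the given file name.
    os.makedirs(dump_dir, exist_ok=True)
    torch.save(tensor.detach().to("cpu"), os.path.join(dump_dir, name))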
@@ -648,7 +696,7 @@ class Llama4TextModel(nn.Module):
             weights=weights,
             eps=config.rms_norm_eps,
         )
-        self.run_index = 0
+        self.run_index = -1
         #self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
         self.gradient_checkpointing = False
@@ -666,7 +714,8 @@ class Llama4TextModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
-        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
+        if self.run_index != -1:
+            torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
         log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}")
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -685,12 +734,15 @@ class Llama4TextModel(nn.Module):
                 seqlen,
                 adapter_data,
                 hpu_attention_meta=hpu_attention_meta,
+                run_index=self.run_index,
             )
-        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
+        if self.run_index != -1:
+            torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
         log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}")
         hidden_states, _ = self.norm(hidden_states)
-        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
+        if self.run_index != -1:
+            torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
         log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}")
         self.run_index += 1
         return hidden_states
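
Note: initializing `self.run_index` to -1 keeps all dumps disabled by default; setting it to 0 re-enables them, and the counter is threaded down to every decoder layer and attention module and incremented once per forward pass. To diff the dumps against a reference run, a small comparison helper (the file names and the "ref." prefix are illustrative, not produced by this commit):

import torch

def compare_dump(path_a: str, path_b: str, atol: float = 1e-3) -> None:
    # Load two dumped tensors on CPU and report how far apart they are.
    a = torch.load(path_a, map_location="cpu").float()
    b = torch.load(path_b, map_location="cpu").float()
    diff = (a - b).abs()
    print(
        f"{path_a}: shape={tuple(a.shape)} "
        f"max_abs_diff={diff.max().item():.3e} mean_abs_diff={diff.mean().item():.3e} "
        f"allclose={torch.allclose(a, b, atol=atol)}"
    )

# e.g. compare_dump("trans.0.Llama4TextModel.norm.hidden_states.pt",
#                   "ref.0.Llama4TextModel.norm.hidden_states.pt")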
@@ -733,7 +785,7 @@ class Llama4ForCausalLM(nn.Module):
             adapter_data=adapter_data,
             hpu_attention_meta=hpu_attention_meta,
         )
-        print(f"lm_head_indices={lm_head_indices}")
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
@@ -1407,7 +1459,6 @@ class Llama4ForConditionalGeneration(nn.Module):
             f"input_ids: {input_ids}, shape = {input_ids.shape}, input_ids={input_ids[-20:]}"
         )
         inputs_embeds = self.text_model.model.embed_tokens(input_ids)
-        print(f"LLama4 inputs_embeds shape: {inputs_embeds.shape}")
         vision_feature_layer = (
             vision_feature_layer
             if vision_feature_layer is not None