From 2dadceaf07d30233844bce5e31f9a95d7d7c5f6f Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Mon, 5 May 2025 14:38:34 +0000
Subject: [PATCH] Debug accuracy issue

Signed-off-by: yuanwu
---
 .../custom_modeling/flash_llama4_modeling.py | 115 +++++++++++++-----
 1 file changed, 83 insertions(+), 32 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index d166e112..2a74d7e5 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -239,7 +239,6 @@ class Llama4TextMoe(nn.Module):
 
     def forward(self, hidden_states, adapter_data):
         #seq_len, hidden_dim = hidden_states.shape
-        print(f"hidden_states.shape: {hidden_states.shape}")
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         tokens_per_expert = hidden_states.shape[0]
         router_logits = self.router(hidden_states)
@@ -255,7 +254,7 @@ class Llama4TextMoe(nn.Module):
         )
         router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
 
-        router_indices = router_indices.reshape(-1, 1).expand(-1, hidden_dim)
+        router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
         routed_in = torch.gather(
             input=hidden_states,
             dim=0,
@@ -268,7 +267,7 @@ class Llama4TextMoe(nn.Module):
 
         # now that we finished expert computation -> we scatter add because we gathered previously
         # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
         # this scales a lot better if you do EP!
-        out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, hidden_dim))
+        out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim))
         return out
 
@@ -455,51 +454,80 @@ class Llama4TextAttention(FlashLlamaAttention):
         slots,
         seqlen,
         adapter_data,
+        run_index,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         #hidden_shape = (*input_shape, -1, self.head_dim)
 
         qkv = self.query_key_value(hidden_states, adapter_data)
-        query_states, kv_states = qkv.split(
+        # query_states, kv_states = qkv.split(
+        #     [
+        #         self.head_size * self.num_heads,
+        #         2 * self.head_size * self.num_key_value_heads,
+        #     ],
+        #     dim=-1,
+        # )
+        query_states, key_states, value_states = qkv.split(
             [
                 self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
             ],
             dim=-1,
         )
 
         query_states = query_states.view(-1, self.num_heads, self.head_size)
-        kv_states = kv_states.view(-1, 2, self.num_key_value_heads, self.head_size)
+        key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
+        value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
+        if run_index != -1:
+            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt")
+            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt")
+            torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt")
 
         # query_states = self.q_proj(hidden_states).view(hidden_shape)
         # key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
         # value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         if self.use_rope: # the 16E model skips rope for long context on certain layers
-            self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin)
+            #self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin)
+            self.rotary_emb(query_states, key_states, cos, sin)
+
+        if run_index != -1:
+            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt")
+            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.key_states.pt")
+
         if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm
             query_states = self.qk_norm(query_states)
-            key_states = self.qk_norm(torch.select(kv_states, dim=1, index=0))
+            key_states = self.qk_norm(key_states)
+
+        if run_index != -1:
+            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.query_states.pt")
+            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt")
+
         # query_states = query_states.transpose(1, 2)
         # key_states = key_states.transpose(1, 2)
 
         kv_cache.store(
-            key=kv_states[:, 0],
-            value=kv_states[:, 1],
+            key=key_states,
+            value=value_states,
             slots=slots,
             kv_scales=self.kv_scales,
         )
 
         # Prefill
         if cu_seqlen_prefill is not None:
+            log_master(
+                logger.debug,
+                f"Prefill: {cu_seqlen_prefill} {seqlen} {slots} {self.kv_head_mapping}"
+            )
             # sdpa
             attn_output = attention(
                 query=query_states,
-                key=kv_states[:, 0],
-                value=kv_states[:, 1],
+                key=key_states,
+                value=value_states,
                 kv_scales=self.kv_scales,
                 kv_cache=kv_cache,
                 seqlen=seqlen,
@@ -507,6 +535,10 @@ class Llama4TextAttention(FlashLlamaAttention):
             )
         # Decode
         else:
+            log_master(
+                logger.debug,
+                f"Decode: {cu_seqlen_prefill} {seqlen} {slots} {self.kv_head_mapping}"
+            )
             attn_output = paged_attention(
                 query_states,
                 kv_cache,
@@ -542,18 +574,18 @@ class Llama4TextDecoderLayer(nn.Module):
         else:
             self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights, layer_idx)
 
-        #self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
-        #self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
-        self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
-        self.post_attention_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
+        self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
+        self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
+        # self.input_layernorm = FastRMSNorm.load(
+        #     prefix=f"{prefix}.input_layernorm",
+        #     weights=weights,
+        #     eps=config.rms_norm_eps,
+        # )
+        # self.post_attention_layernorm = FastRMSNorm.load(
+        #     prefix=f"{prefix}.post_attention_layernorm",
+        #     weights=weights,
+        #     eps=config.rms_norm_eps,
+        # )
 
         self.layer_idx = layer_idx
 
@@ -570,9 +602,14 @@ class Llama4TextDecoderLayer(nn.Module):
         seqlen,
         adapter_data,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+        run_index
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
-        hidden_states, _ = self.input_layernorm(hidden_states)
+        if run_index != -1:
f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input.hidden_states.pt") + hidden_states = self.input_layernorm(hidden_states) + if run_index != -1: + torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt") attention_states = self.self_attn( hidden_states, @@ -583,16 +620,27 @@ class Llama4TextDecoderLayer(nn.Module): slots, seqlen, adapter_data, + run_index, hpu_attention_meta=hpu_attention_meta, ) + if run_index != -1: + torch_save(attention_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.attention_states.pt") hidden_states = residual + attention_states + if run_index != -1: + torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.hidden_states.pt") # Fully Connected residual = hidden_states - hidden_states, _ = self.post_attention_layernorm(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + if run_index != -1: + torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt") hidden_states = self.feed_forward(hidden_states, adapter_data) + if run_index != -1: + torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt") hidden_states = residual + hidden_states.view(residual.shape) + if run_index != -1: + torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.output.hidden_states.pt") #outputs = (hidden_states,) return hidden_states # if residual is None: @@ -648,7 +696,7 @@ class Llama4TextModel(nn.Module): weights=weights, eps=config.rms_norm_eps, ) - self.run_index = 0 + self.run_index = -1 #self.rotary_emb = Llama4TextRotaryEmbedding(config=config) self.gradient_checkpointing = False @@ -666,7 +714,8 @@ class Llama4TextModel(nn.Module): ) -> torch.Tensor: hidden_states = inputs_embeds - torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.input.hidden_states.pt") + if self.run_index != -1: + torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt") log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}") # Get rotary cos and sin for this forward # Avoid to index in each layer @@ -685,12 +734,15 @@ class Llama4TextModel(nn.Module): seqlen, adapter_data, hpu_attention_meta=hpu_attention_meta, + run_index=self.run_index, ) - torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.layers.hidden_states.pt") + if self.run_index != -1: + torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.layers.hidden_states.pt") log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}") hidden_states, _ = self.norm(hidden_states) - torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.norm.hidden_states.pt") + if self.run_index != -1: + torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.norm.hidden_states.pt") log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}") self.run_index += 1 return hidden_states @@ -733,7 +785,7 @@ class Llama4ForCausalLM(nn.Module): adapter_data=adapter_data, hpu_attention_meta=hpu_attention_meta, ) - + print(f"lm_head_indices={lm_head_indices}") if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] @@ -1407,7 +1459,6 @@ class Llama4ForConditionalGeneration(nn.Module): f"input_ids: {input_ids}, shape = {input_ids.shape}, input_ids={input_ids[-20:]}" ) inputs_embeds = 
         inputs_embeds = self.text_model.model.embed_tokens(input_ids)
-        print(f"LLama4 inputs_embeds shape: {inputs_embeds.shape}")
         vision_feature_layer = (
             vision_feature_layer
             if vision_feature_layer is not None