Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
Commit 2dadceaf07 (parent dafc597a8b)

Debug accuracy issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
@@ -239,7 +239,6 @@ class Llama4TextMoe(nn.Module):
     def forward(self, hidden_states, adapter_data):
         #seq_len, hidden_dim = hidden_states.shape
-        print(f"hidden_states.shape: {hidden_states.shape}")
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         tokens_per_expert = hidden_states.shape[0]
         router_logits = self.router(hidden_states)
@@ -255,7 +254,7 @@ class Llama4TextMoe(nn.Module):
         )
         router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

-        router_indices = router_indices.reshape(-1, 1).expand(-1, hidden_dim)
+        router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
         routed_in = torch.gather(
             input=hidden_states,
             dim=0,
@@ -268,7 +267,7 @@ class Llama4TextMoe(nn.Module):
         # now that we finished expert computation -> we scatter add because we gathered previously
         # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
         # this scales a lot better if you do EP!
-        out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, hidden_dim))
+        out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim))

         return out

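Context for the two `hidden_dim` -> `self.hidden_dim` fixes above: the unpacking `seq_len, hidden_dim = hidden_states.shape` is commented out, so `hidden_dim` is not defined locally and the gather/scatter indices must use the module attribute instead. The following is a minimal, self-contained sketch of the same gather -> expert -> scatter_add routing pattern, with toy sizes and a placeholder "expert"; it is an illustration, not the repository's Llama4TextMoe.

import torch

# Toy sizes (assumptions, not the model's real dimensions).
num_tokens, hidden_dim, top_k = 4, 8, 2

hidden_states = torch.randn(num_tokens, hidden_dim)
# One source-token index per routed slot; flattening gives num_tokens * top_k rows.
router_indices = torch.randint(0, num_tokens, (num_tokens * top_k,))

# Expand the indices so every gathered/scattered row covers the full hidden dim,
# mirroring `reshape(-1, 1).expand(-1, self.hidden_dim)` in the diff.
expanded = router_indices.reshape(-1, 1).expand(-1, hidden_dim)

# Gather the routed tokens, run a stand-in "expert"...
routed_in = torch.gather(input=hidden_states, dim=0, index=expanded)
routed_out = routed_in * 2.0  # placeholder for the expert MLP

# ...then scatter-add the expert outputs back onto their source tokens.
out = torch.zeros_like(hidden_states)
out.scatter_add_(dim=0, index=expanded, src=routed_out.view(-1, hidden_dim))
print(out.shape)  # torch.Size([4, 8])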
@@ -455,51 +454,80 @@ class Llama4TextAttention(FlashLlamaAttention):
         slots,
         seqlen,
         adapter_data,
+        run_index,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         #hidden_shape = (*input_shape, -1, self.head_dim)
         qkv = self.query_key_value(hidden_states, adapter_data)
-        query_states, kv_states = qkv.split(
+        # query_states, kv_states = qkv.split(
+        #     [
+        #         self.head_size * self.num_heads,
+        #         2 * self.head_size * self.num_key_value_heads,
+        #     ],
+        #     dim=-1,
+        # )
+        query_states, key_states, value_states = qkv.split(
             [
                 self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
             ],
             dim=-1,
         )

         query_states = query_states.view(-1, self.num_heads, self.head_size)
-        kv_states = kv_states.view(-1, 2, self.num_key_value_heads, self.head_size)
+        key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
+        value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
+
+        if run_index != -1:
+            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt")
+            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt")
+            torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt")

         # query_states = self.q_proj(hidden_states).view(hidden_shape)
         # key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
         # value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

         if self.use_rope: # the 16E model skips rope for long context on certain layers
-            self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin)
+            #self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin)
+            self.rotary_emb(query_states, key_states, cos, sin)
+
+        if run_index != -1:
+            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt")
+            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.key_states.pt")

         if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm
             query_states = self.qk_norm(query_states)
-            key_states = self.qk_norm(torch.select(kv_states, dim=1, index=0))
+            key_states = self.qk_norm(key_states)
+
+        if run_index != -1:
+            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.query_states.pt")
+            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt")

         # query_states = query_states.transpose(1, 2)
         # key_states = key_states.transpose(1, 2)
         kv_cache.store(
-            key=kv_states[:, 0],
-            value=kv_states[:, 1],
+            key=key_states,
+            value=value_states,
             slots=slots,
             kv_scales=self.kv_scales,
         )

         # Prefill
         if cu_seqlen_prefill is not None:
+            log_master(
+                logger.debug,
+                f"Prefill: {cu_seqlen_prefill} {seqlen} {slots} {self.kv_head_mapping}"
+            )
             # sdpa
             attn_output = attention(
                 query=query_states,
-                key=kv_states[:, 0],
-                value=kv_states[:, 1],
+                key=key_states,
+                value=value_states,
                 kv_scales=self.kv_scales,
                 kv_cache=kv_cache,
                 seqlen=seqlen,
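The main change in the hunk above replaces the fused [query | key-value] split with a three-way [query | key | value] split, so key and value are carried as separate tensors through rope, qk_norm, and the KV-cache store. Below is a minimal shape-only sketch of the new layout; the token count, head counts, and head size are made-up numbers for illustration, not values read from the model config.

import torch

num_tokens = 3
num_heads, num_key_value_heads, head_size = 8, 2, 16  # assumed sizes

# Fused QKV projection output: [q | k | v] concatenated on the last dim.
qkv = torch.randn(num_tokens, (num_heads + 2 * num_key_value_heads) * head_size)

query_states, key_states, value_states = qkv.split(
    [
        head_size * num_heads,
        head_size * num_key_value_heads,
        head_size * num_key_value_heads,
    ],
    dim=-1,
)

query_states = query_states.view(-1, num_heads, head_size)
key_states = key_states.view(-1, num_key_value_heads, head_size)
value_states = value_states.view(-1, num_key_value_heads, head_size)
print(query_states.shape, key_states.shape, value_states.shape)
# torch.Size([3, 8, 16]) torch.Size([3, 2, 16]) torch.Size([3, 2, 16])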
@@ -507,6 +535,10 @@ class Llama4TextAttention(FlashLlamaAttention):
             )
         # Decode
         else:
+            log_master(
+                logger.debug,
+                f"Decode: {cu_seqlen_prefill} {seqlen} {slots} {self.kv_head_mapping}"
+            )
             attn_output = paged_attention(
                 query_states,
                 kv_cache,
@@ -542,18 +574,18 @@ class Llama4TextDecoderLayer(nn.Module):
         else:
             self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights, layer_idx)

-        #self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
-        #self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
-        self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
-        self.post_attention_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
+        self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
+        self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
+        # self.input_layernorm = FastRMSNorm.load(
+        #     prefix=f"{prefix}.input_layernorm",
+        #     weights=weights,
+        #     eps=config.rms_norm_eps,
+        # )
+        # self.post_attention_layernorm = FastRMSNorm.load(
+        #     prefix=f"{prefix}.post_attention_layernorm",
+        #     weights=weights,
+        #     eps=config.rms_norm_eps,
+        # )


         self.layer_idx = layer_idx
@@ -570,9 +602,14 @@ class Llama4TextDecoderLayer(nn.Module):
         seqlen,
         adapter_data,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+        run_index
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
-        hidden_states, _ = self.input_layernorm(hidden_states)
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input.hidden_states.pt")
+        hidden_states = self.input_layernorm(hidden_states)
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt")

         attention_states = self.self_attn(
             hidden_states,
@@ -583,16 +620,27 @@ class Llama4TextDecoderLayer(nn.Module):
             slots,
             seqlen,
             adapter_data,
+            run_index,
             hpu_attention_meta=hpu_attention_meta,
         )
+        if run_index != -1:
+            torch_save(attention_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.attention_states.pt")
         hidden_states = residual + attention_states
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.hidden_states.pt")

         # Fully Connected
         residual = hidden_states

-        hidden_states, _ = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt")
         hidden_states = self.feed_forward(hidden_states, adapter_data)
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt")
         hidden_states = residual + hidden_states.view(residual.shape)
+        if run_index != -1:
+            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.output.hidden_states.pt")
         #outputs = (hidden_states,)
         return hidden_states
         # if residual is None:
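The `run_index != -1` blocks added throughout the attention and decoder-layer forward passes dump intermediate activations to disk so they can be diffed against a reference run; setting `run_index = -1` disables them. The diff does not show `torch_save` itself. A plausible minimal stand-in (an assumption about its behaviour, not the repository's definition) is a thin wrapper over `torch.save` that detaches the tensor and copies it to CPU first:

import torch

def torch_save(tensor: torch.Tensor, path: str) -> None:
    """Hypothetical helper: detach, move to CPU, and save for offline comparison."""
    torch.save(tensor.detach().to("cpu"), path)

# Usage mirroring the diff (run_index == -1 disables dumping).
run_index, layer_idx = 0, 0
hidden_states = torch.randn(4, 8)
if run_index != -1:
    torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{layer_idx}.input.hidden_states.pt")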
@@ -648,7 +696,7 @@ class Llama4TextModel(nn.Module):
             weights=weights,
             eps=config.rms_norm_eps,
         )
-        self.run_index = 0
+        self.run_index = -1

         #self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
         self.gradient_checkpointing = False
@@ -666,7 +714,8 @@ class Llama4TextModel(nn.Module):
     ) -> torch.Tensor:

         hidden_states = inputs_embeds
-        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
+        if self.run_index != -1:
+            torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
         log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}")
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -685,12 +734,15 @@ class Llama4TextModel(nn.Module):
                 seqlen,
                 adapter_data,
                 hpu_attention_meta=hpu_attention_meta,
+                run_index=self.run_index,
             )

-        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
+        if self.run_index != -1:
+            torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
         log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}")
         hidden_states, _ = self.norm(hidden_states)
-        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
+        if self.run_index != -1:
+            torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
         log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}")
         self.run_index += 1
         return hidden_states
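Given the `trans.*` file prefix introduced here (replacing the earlier `tgi.*` names), the dumps appear intended for an offline, tensor-by-tensor comparison between two runs. The following is a hedged sketch of such a comparison; the helper name, file paths, and tolerance are illustrative assumptions, not part of the commit.

import torch

def compare_dump(path_a: str, path_b: str, atol: float = 1e-3) -> None:
    """Load two saved activations and report the largest absolute difference."""
    a = torch.load(path_a, map_location="cpu")
    b = torch.load(path_b, map_location="cpu")
    diff = (a.float() - b.float()).abs().max().item()
    status = "OK" if diff <= atol else "MISMATCH"
    print(f"{status}  max|a-b|={diff:.6f}  {path_a} vs {path_b}")

# Example (assumes both dump files exist in the working directory; names follow
# the conventions used in the diff).
compare_dump(
    "trans.0.Llama4TextModel.norm.hidden_states.pt",
    "tgi.0.Llama4TextModel.norm.hidden_states.pt",
)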
@@ -733,7 +785,7 @@ class Llama4ForCausalLM(nn.Module):
             adapter_data=adapter_data,
             hpu_attention_meta=hpu_attention_meta,
         )
-
+        print(f"lm_head_indices={lm_head_indices}")
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]

@@ -1407,7 +1459,6 @@ class Llama4ForConditionalGeneration(nn.Module):
             f"input_ids: {input_ids}, shape = {input_ids.shape}, input_ids={input_ids[-20:]}"
         )
         inputs_embeds = self.text_model.model.embed_tokens(input_ids)
-        print(f"LLama4 inputs_embeds shape: {inputs_embeds.shape}")
         vision_feature_layer = (
             vision_feature_layer
             if vision_feature_layer is not None