Fix experts issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-08 03:12:22 +00:00
parent 2dadceaf07
commit a3967a57bc
2 changed files with 706 additions and 196 deletions


@@ -61,6 +61,11 @@ from text_generation_server.utils.weights import (
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
def torch_save(tensor, name):
    # Only save on the main process (rank 0) when using distributed training
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        torch.save(tensor, name)
def load_attention(config, prefix: str, weights, layer_id):
    # Only defined in granite.
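
A dump written by torch_save can be reloaded later with torch.load for offline inspection. A minimal sketch, not part of the commit; the run_index and layer values are illustrative and the filename pattern mirrors the f-strings added in FlashLlamaLayer.forward below:

import torch

# Sketch only: reload one dumped tensor and inspect it.
# map_location="cpu" assumes the saved HPU tensor can be mapped back to CPU.
run_index, layer = 0, 0
attn = torch.load(
    f"trans.{run_index}.Llama4TextDecoderLayer.{layer}.attention.attention_states.pt",
    map_location="cpu",
)
print(attn.shape, attn.dtype)
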
@@ -377,7 +382,7 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
    def __init__(self, index, prefix, config, weights):
        super().__init__()
        self.index = index
        with no_fp8(weights):
            self.self_attn = FlashLlamaAttention(
                index=index,
@@ -438,6 +443,7 @@ class FlashLlamaLayer(nn.Module):
        seqlen,
        adapter_data,
        cross_attention_states,
        run_index,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -454,6 +460,10 @@ class FlashLlamaLayer(nn.Module):
            adapter_data,
            hpu_attention_meta=hpu_attention_meta,
        )
        if run_index != -1:
            torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_states.pt")
        if self.residual_multiplier is not None:
            attn_output *= self.residual_multiplier
@@ -462,6 +472,10 @@ class FlashLlamaLayer(nn.Module):
        )
        mlp_output = self.mlp(normed_attn_res_output, adapter_data)
        if run_index != -1:
            torch_save(mlp_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.mlp.pt")
        if self.residual_multiplier is not None:
            mlp_output *= self.residual_multiplier
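
Together, the two hooks above dump the per-layer attention output and the per-layer MLP output for every recorded step, so they can be diffed against dumps from a reference run that uses the same naming scheme. A minimal comparison sketch, not part of the commit; the ref/ directory and the tolerance are assumptions:

import torch

# Sketch only: compare a TGI dump against a reference dump with the same name
# stored under a hypothetical "ref/" directory.
def compare(run_index: int, layer: int, part: str, atol: float = 1e-3) -> bool:
    name = f"trans.{run_index}.Llama4TextDecoderLayer.{layer}.{part}.pt"
    a = torch.load(name, map_location="cpu")
    b = torch.load(f"ref/{name}", map_location="cpu")
    print(name, "max abs diff:", (a.float() - b.float()).abs().max().item())
    return torch.allclose(a.float(), b.float(), atol=atol)

compare(0, 0, "attention.attention_states")
compare(0, 0, "mlp")
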
@@ -471,7 +485,7 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaModel(torch.nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.run_index = -1
        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
@@ -568,11 +582,12 @@ class FlashLlamaModel(torch.nn.Module):
                seqlen,
                adapter_data,
                cross_attention_states,
                self.run_index,
                hpu_attention_meta=hpu_attention_meta,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        self.run_index += 1
        return hidden_states
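
Note on the counter added above: self.run_index starts at -1, so the very first forward pass through FlashLlamaModel (typically warmup) writes no dumps, and the counter is incremented once per forward pass after the final norm, so every FlashLlamaLayer in the same pass shares one run_index. A small sketch of the resulting file layout; the layer and step counts are illustrative:

# Illustrative only: dump files expected after three forward passes of a
# 4-layer model. The first pass (run_index == -1) writes nothing; the next
# two passes write files for run_index 0 and 1.
num_layers, dumped_steps = 4, 2
for run_index in range(dumped_steps):
    for layer in range(num_layers):
        print(f"trans.{run_index}.Llama4TextDecoderLayer.{layer}.attention.attention_states.pt")
        print(f"trans.{run_index}.Llama4TextDecoderLayer.{layer}.mlp.pt")
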