Fix experts issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-08 03:12:22 +00:00
parent 2dadceaf07
commit a3967a57bc
2 changed files with 706 additions and 196 deletions


@@ -61,6 +61,11 @@ from text_generation_server.utils.weights import (
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader


def torch_save(tensor, name):
    # Only save on the main process (rank 0) when using distributed training
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        torch.save(tensor, name)


def load_attention(config, prefix: str, weights, layer_id):
    # Only defined in granite.
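Note on the helper above: under torch.distributed, every tensor-parallel rank executes this code, so the rank-0 guard keeps the ranks from racing to write the same dump file. A minimal standalone sketch of the same pattern (the example tensor and file name are illustrative, not from this commit):

import torch
import torch.distributed as dist

def torch_save(tensor, name):
    # Write at most once per job: rank 0 when distributed is initialized,
    # unconditionally in a single-process run.
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(tensor, name)

torch_save(torch.randn(2, 3), "trans.0.example.pt")  # single process: file is written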
@@ -377,7 +382,7 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
    def __init__(self, index, prefix, config, weights):
        super().__init__()
        self.index = index
        with no_fp8(weights):
            self.self_attn = FlashLlamaAttention(
                index=index,
@@ -438,6 +443,7 @@ class FlashLlamaLayer(nn.Module):
        seqlen,
        adapter_data,
        cross_attention_states,
        run_index,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -454,6 +460,10 @@ class FlashLlamaLayer(nn.Module):
            adapter_data,
            hpu_attention_meta=hpu_attention_meta,
        )
        if run_index != -1:
            torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_states.pt")

        if self.residual_multiplier is not None:
            attn_output *= self.residual_multiplier
@@ -462,6 +472,10 @@ class FlashLlamaLayer(nn.Module):
        )
        mlp_output = self.mlp(normed_attn_res_output, adapter_data)
        if run_index != -1:
            torch_save(mlp_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.mlp.pt")

        if self.residual_multiplier is not None:
            mlp_output *= self.residual_multiplier
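The two dump sites above share one naming scheme: trans.<run_index>.Llama4TextDecoderLayer.<layer>.attention.attention_states.pt for the attention output and trans.<run_index>.Llama4TextDecoderLayer.<layer>.mlp.pt for the MLP output. As a sketch, the files one instrumented forward pass (run_index 0) would produce for a hypothetical 4-layer model:

# Sketch only: the layer count is hypothetical; names mirror the f-strings above.
run_index = 0
num_layers = 4
for i in range(num_layers):
    print(f"trans.{run_index}.Llama4TextDecoderLayer.{i}.attention.attention_states.pt")
    print(f"trans.{run_index}.Llama4TextDecoderLayer.{i}.mlp.pt")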
@@ -471,7 +485,7 @@
class FlashLlamaModel(torch.nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.run_index = -1
        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
@@ -568,11 +582,12 @@ class FlashLlamaModel(torch.nn.Module):
                seqlen,
                adapter_data,
                cross_attention_states,
                self.run_index,
                hpu_attention_meta=hpu_attention_meta,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        self.run_index += 1
        return hidden_states
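Since run_index starts at -1 and is incremented only after a full forward pass, the first pass (presumably warmup) dumps nothing, and later passes dump every layer's attention and MLP outputs under indices 0, 1, 2, .... A sketch of how two such dump directories, e.g. a suspect run and a reference run, could be compared to localize the experts mismatch; the directory names and tolerance are assumptions:

import glob
import os

import torch

def compare_dumps(dir_a, dir_b, atol=1e-3):
    # Pair up same-named dumps from two runs and report the largest
    # element-wise deviation per file.
    for path_a in sorted(glob.glob(os.path.join(dir_a, "trans.*.pt"))):
        name = os.path.basename(path_a)
        path_b = os.path.join(dir_b, name)
        if not os.path.exists(path_b):
            print(f"missing in {dir_b}: {name}")
            continue
        a = torch.load(path_a, map_location="cpu").float()
        b = torch.load(path_b, map_location="cpu").float()
        max_diff = (a - b).abs().max().item()
        print(f"{'OK ' if max_diff <= atol else 'DIFF'} {name}: max_abs_diff={max_diff:.3e}")

compare_dumps("dumps_device", "dumps_reference")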