Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-05 10:08:29 +00:00
parent ccddbba752
commit dafc597a8b
2 changed files with 36 additions and 19 deletions


@@ -71,6 +71,11 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE
 _CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B"
 _CONFIG_FOR_DOC = "Llama4Config"
 
+def torch_save(tensor, name):
+    # Only save on the main process (rank 0) when using distributed training
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        torch.save(tensor, name)
+
 class Llama4TextExperts(nn.Module):
     def __init__(self, prefix, config: Llama4TextConfig, weights):
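
The torch_save helper above dumps intermediate tensors from rank 0 only, so a distributed run produces a single set of .pt files. A minimal sketch of how such dumps might be compared offline against a reference run; the ref.*.pt file name and the tolerances are assumptions, not part of this commit:

    import torch

    # Hypothetical file names: "tgi.*" comes from this commit's torch_save calls,
    # "ref.*" is assumed to be produced the same way by a reference implementation.
    tgi = torch.load("tgi.0.Llama4TextModel.norm.hidden_states.pt", map_location="cpu")
    ref = torch.load("ref.0.Llama4TextModel.norm.hidden_states.pt", map_location="cpu")

    # Compare in float32 to avoid spurious mismatches from bf16 storage.
    diff = (tgi.float() - ref.float()).abs().max().item()
    print(f"max abs diff: {diff:.6e}")
    print("allclose:", torch.allclose(tgi.float(), ref.float(), atol=1e-3, rtol=1e-3))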
@@ -233,10 +238,11 @@ class Llama4TextMoe(nn.Module):
     def forward(self, hidden_states, adapter_data):
-        seq_len, hidden_dim = hidden_states.shape
+        #seq_len, hidden_dim = hidden_states.shape
+        print(f"hidden_states.shape: {hidden_states.shape}")
         hidden_states = hidden_states.view(-1, self.hidden_dim)
+        tokens_per_expert = hidden_states.shape[0]
         router_logits = self.router(hidden_states)
-        tokens_per_expert = seq_len
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
@@ -536,18 +542,18 @@ class Llama4TextDecoderLayer(nn.Module):
         else:
             self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights, layer_idx)
 
-        self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
-        self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
-        # self.input_layernorm = FastRMSNorm.load(
-        #     prefix=f"{prefix}.input_layernorm",
-        #     weights=weights,
-        #     eps=config.rms_norm_eps,
-        # )
-        # self.post_attention_layernorm = FastRMSNorm.load(
-        #     prefix=f"{prefix}.post_attention_layernorm",
-        #     weights=weights,
-        #     eps=config.rms_norm_eps,
-        # )
+        #self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
+        #self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
+        self.input_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.input_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
+        self.post_attention_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.post_attention_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
         self.layer_idx = layer_idx
@@ -566,7 +572,7 @@ class Llama4TextDecoderLayer(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states, _ = self.input_layernorm(hidden_states)
 
         attention_states = self.self_attn(
             hidden_states,
@@ -584,7 +590,7 @@ class Llama4TextDecoderLayer(nn.Module):
         # Fully Connected
         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, _ = self.post_attention_layernorm(hidden_states)
         hidden_states = self.feed_forward(hidden_states, adapter_data)
         hidden_states = residual + hidden_states.view(residual.shape)
         #outputs = (hidden_states,)
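
The two call-site changes above go together with the switch to FastRMSNorm.load in the constructor: as the unpacking in this diff suggests, the FastRMSNorm-style layers return a pair rather than a single tensor, so callers take the first element and discard the second. A rough, self-contained illustration of that convention (a simplified stand-in, not the actual FastRMSNorm implementation):

    import torch
    import torch.nn as nn

    class TupleRMSNorm(nn.Module):
        """Simplified stand-in that mimics a norm returning (normed, residual)."""
        def __init__(self, hidden_size, eps=1e-5):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, hidden_states, residual=None):
            # Optionally fold in a residual before normalizing, then hand it back
            # so the caller can reuse it on the next layer.
            if residual is not None:
                hidden_states = hidden_states + residual
            residual = hidden_states
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            normed = hidden_states * torch.rsqrt(variance + self.eps)
            return self.weight * normed, residual

    norm = TupleRMSNorm(16)
    x = torch.randn(4, 16)
    hidden_states, _ = norm(x)  # mirrors `hidden_states, _ = self.input_layernorm(hidden_states)`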
@@ -642,9 +648,9 @@ class Llama4TextModel(nn.Module):
             weights=weights,
             eps=config.rms_norm_eps,
         )
-        #self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
+        self.run_index = 0
+        self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
         self.gradient_checkpointing = False
 
     def forward(
@@ -660,6 +666,8 @@ class Llama4TextModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
+        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
+        log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}")
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
@@ -679,7 +687,12 @@ class Llama4TextModel(nn.Module):
                 hpu_attention_meta=hpu_attention_meta,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
+        log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}")
+        hidden_states, _ = self.norm(hidden_states)
+        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
+        log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}")
+        self.run_index += 1
 
         return hidden_states


@@ -1520,6 +1520,10 @@ class FlashCausalLM(Model):
         if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
             logger.info("skip warmup hpu graph, not recommmended")
             del _batch, batch
+            print(f"max_input_tokens: {max_input_tokens}")
+            print(f"max_total_tokens: {max_total_tokens}")
+            print(f"num_blocks: {num_blocks}")
+            print(f"BLOCK_SIZE: {BLOCK_SIZE}")
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
         self.warmup_hpu_graph(batch)
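
The prints added above expose the values behind the early return: the first returned value is num_blocks * BLOCK_SIZE, i.e. the total number of tokens the allocated KV-cache blocks can hold, alongside the input and total token limits. A small worked example with made-up numbers (illustrative only, not from a real run):

    # Illustrative values only; the real ones come from the HPU warmup above.
    num_blocks = 1024        # free KV-cache blocks reported by warmup (assumed)
    BLOCK_SIZE = 128         # tokens held per block (assumed)
    max_input_tokens = 4096  # assumed prompt limit
    max_total_tokens = 8192  # assumed prompt + generation limit

    kv_cache_tokens = int(num_blocks * BLOCK_SIZE)
    print(kv_cache_tokens)   # 131072 tokens of KV-cache capacity
    assert max_total_tokens <= kv_cache_tokens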