diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index 236f851e..d166e112 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -71,6 +71,11 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE
 _CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B"
 _CONFIG_FOR_DOC = "Llama4Config"
 
+def torch_save(tensor, name):
+    # Only save on the main process (rank 0) when using distributed training
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        torch.save(tensor, name)
+
 
 class Llama4TextExperts(nn.Module):
     def __init__(self, prefix, config: Llama4TextConfig, weights):
@@ -233,10 +238,11 @@ class Llama4TextMoe(nn.Module):
 
     def forward(self, hidden_states, adapter_data):
-        seq_len, hidden_dim = hidden_states.shape
+        #seq_len, hidden_dim = hidden_states.shape
+        print(f"hidden_states.shape: {hidden_states.shape}")
         hidden_states = hidden_states.view(-1, self.hidden_dim)
+        tokens_per_expert = hidden_states.shape[0]
         router_logits = self.router(hidden_states)
-        tokens_per_expert = seq_len
 
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
 
         router_scores = (
@@ -536,18 +542,18 @@ class Llama4TextDecoderLayer(nn.Module):
         else:
             self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights, layer_idx)
 
-        self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
-        self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
-        # self.input_layernorm = FastRMSNorm.load(
-        #     prefix=f"{prefix}.input_layernorm",
-        #     weights=weights,
-        #     eps=config.rms_norm_eps,
-        # )
-        # self.post_attention_layernorm = FastRMSNorm.load(
-        #     prefix=f"{prefix}.post_attention_layernorm",
-        #     weights=weights,
-        #     eps=config.rms_norm_eps,
-        # )
+        #self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
+        #self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
+        self.input_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.input_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
+        self.post_attention_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.post_attention_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
 
         self.layer_idx = layer_idx
 
@@ -566,7 +572,7 @@ class Llama4TextDecoderLayer(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states, _ = self.input_layernorm(hidden_states)
 
         attention_states = self.self_attn(
             hidden_states,
@@ -584,7 +590,7 @@ class Llama4TextDecoderLayer(nn.Module):
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, _ = self.post_attention_layernorm(hidden_states)
         hidden_states = self.feed_forward(hidden_states, adapter_data)
         hidden_states = residual + hidden_states.view(residual.shape)
         #outputs = (hidden_states,)
@@ -642,9 +648,9 @@ class Llama4TextModel(nn.Module):
             weights=weights,
             eps=config.rms_norm_eps,
         )
+        self.run_index = 0
 
-
-        self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
+        #self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
         self.gradient_checkpointing = False
 
     def forward(
@@ -660,6 +666,8 @@ class Llama4TextModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
+        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
+        log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}")
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
@@ -679,7 +687,12 @@ class Llama4TextModel(nn.Module):
                 hpu_attention_meta=hpu_attention_meta,
             )
 
-        hidden_states, _ = self.norm(hidden_states, residual)
+        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
+        log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}")
+        hidden_states, _ = self.norm(hidden_states)
+        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
+        log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}")
+        self.run_index += 1
 
         return hidden_states
 
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 5503efe4..1f55c27e 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1520,6 +1520,10 @@ class FlashCausalLM(Model):
         if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
             logger.info("skip warmup hpu graph, not recommmended")
             del _batch, batch
+            print(f"max_input_tokens: {max_input_tokens}")
+            print(f"max_total_tokens: {max_total_tokens}")
+            print(f"num_blocks: {num_blocks}")
+            print(f"BLOCK_SIZE: {BLOCK_SIZE}")
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
         self.warmup_hpu_graph(batch)
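Note on using the dumps: the torch_save calls above write the hidden states at three points of Llama4TextModel.forward (embedding input, after the decoder layers, after the final norm) to tgi.{run_index}.*.pt files on rank 0. Below is a minimal offline comparison sketch, not part of the patch; the ref.*.pt naming and the tolerance values are assumptions for whatever reference dump is being compared against.

# compare_dumps.py -- hedged sketch for checking the tensors written by torch_save().
# Only the tgi.* file names come from the patch above; the ref.* names are
# hypothetical dumps produced by a reference implementation.
import torch

def compare(tgi_path: str, ref_path: str, atol: float = 1e-3, rtol: float = 1e-3) -> None:
    # Load both tensors on CPU and upcast so bf16/fp16 dumps compare cleanly.
    tgi = torch.load(tgi_path, map_location="cpu").float()
    ref = torch.load(ref_path, map_location="cpu").float()
    if tgi.shape != ref.shape:
        print(f"{tgi_path}: shape mismatch {tuple(tgi.shape)} vs {tuple(ref.shape)}")
        return
    max_abs = (tgi - ref).abs().max().item()
    ok = torch.allclose(tgi, ref, atol=atol, rtol=rtol)
    print(f"{tgi_path}: max_abs_diff={max_abs:.6f} allclose={ok}")

if __name__ == "__main__":
    run_index = 0  # first forward pass; matches self.run_index in Llama4TextModel
    for stage in ("input", "layers", "norm"):
        compare(
            f"tgi.{run_index}.Llama4TextModel.{stage}.hidden_states.pt",
            f"ref.{run_index}.Llama4TextModel.{stage}.hidden_states.pt",  # hypothetical reference dump
        )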