From b1b79bf32d654061276b6e8026eee9e3ab6ea0a1 Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Thu, 29 May 2025 08:37:25 +0000
Subject: [PATCH] Fix the Llama-4-Maverick-17B-128E crash issue

Signed-off-by: yuanwu
---
 .../models/custom_modeling/flash_llama4_modeling.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index 11864c52..0e3af85a 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -48,7 +48,6 @@ from text_generation_server.layers.attention import (
 )
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaAttention,
-    LlamaMLP,
 )
 
 
@@ -444,7 +443,7 @@ class Llama4TextDecoderLayer(nn.Module):
         if self.is_moe_layer:  # the 128E model interleaves dense / sparse
             self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights)
         else:
-            self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights)
+            self.feed_forward = Llama4TextMLP(f"{prefix}.feed_forward", config, weights)
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm",
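
Note (not part of the patch): the else-branch matters for this checkpoint because Llama-4-Maverick-17B-128E interleaves sparse (MoE) and dense decoder layers, so the dense branch is actually exercised and must build the Llama4-specific MLP rather than the plain-Llama one. Below is a minimal Python sketch of that selection logic; `ToyConfig`, its `moe_layers` field, and `pick_feed_forward` are illustrative assumptions, not the TGI implementation in flash_llama4_modeling.py.

# Illustrative sketch only -- assumed names, not the TGI code.
class ToyConfig:
    # Assumption: the config lists which layer indices are MoE; in the
    # 128E Maverick checkpoint dense and sparse layers interleave.
    moe_layers = [0, 2, 4, 6]      # hypothetical interleaving pattern
    num_hidden_layers = 8


def pick_feed_forward(layer_idx: int, config: ToyConfig) -> str:
    """Mirror the decoder layer's branch: sparse layers get the MoE block,
    dense layers get the Llama4 MLP (per this patch; constructing them
    with LlamaMLP crashed on this checkpoint)."""
    if layer_idx in config.moe_layers:
        return "Llama4TextMoe"
    return "Llama4TextMLP"         # the patch's fix; previously LlamaMLP


if __name__ == "__main__":
    cfg = ToyConfig()
    for i in range(cfg.num_hidden_layers):
        print(i, pick_feed_forward(i, cfg))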