diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py index 11864c52..0e3af85a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -48,7 +48,6 @@ from text_generation_server.layers.attention import ( ) from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaAttention, - LlamaMLP, ) @@ -444,7 +443,7 @@ class Llama4TextDecoderLayer(nn.Module): if self.is_moe_layer: # the 128E model interleaves dense / sparse self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights) else: - self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights) + self.feed_forward = Llama4TextMLP(f"{prefix}.feed_forward", config, weights) self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm",