mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-16 22:32:07 +00:00
[gaudi] Fix the Llama-4-Maverick-17B-128E crash issue (#3246)
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent
70217ac345
commit
6b6e30a6f6
@ -48,7 +48,6 @@ from text_generation_server.layers.attention import (
|
|||||||
)
|
)
|
||||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||||
FlashLlamaAttention,
|
FlashLlamaAttention,
|
||||||
LlamaMLP,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -444,7 +443,7 @@ class Llama4TextDecoderLayer(nn.Module):
|
|||||||
if self.is_moe_layer: # the 128E model interleaves dense / sparse
|
if self.is_moe_layer: # the 128E model interleaves dense / sparse
|
||||||
self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights)
|
self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights)
|
||||||
else:
|
else:
|
||||||
self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights)
|
self.feed_forward = Llama4TextMLP(f"{prefix}.feed_forward", config, weights)
|
||||||
|
|
||||||
self.input_layernorm = FastRMSNorm.load(
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
prefix=f"{prefix}.input_layernorm",
|
prefix=f"{prefix}.input_layernorm",
|
||||||
|
Loading…
Reference in New Issue
Block a user