From 70217ac3454396d9a08a25ce1aa8b40a1fe87069 Mon Sep 17 00:00:00 2001
From: Yuan Wu
Date: Thu, 29 May 2025 15:58:24 +0800
Subject: [PATCH] [Gaudi] Fix the OOM issue of Llama-4-Scout-17B-16E-Instruct
 (#3245)

Signed-off-by: yuanwu
---
 .../models/custom_modeling/flash_llama_modeling.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 0edea03a..dfb16621 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -143,12 +143,14 @@ class FlashLlamaAttention(torch.nn.Module):
         config.num_key_value_heads = getattr(
             config, "num_key_value_heads", config.num_attention_heads
         )
-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+
+        if config.model_type != "llama4_text":
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_theta,
+                device=weights.device,
+            )
 
         # `config.attention_multiplier` is used in Granite
         self.softmax_scale = getattr(
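
Note: the following is a minimal, self-contained sketch of the pattern this patch introduces, using hypothetical ToyConfig / RotaryCache / ToyAttention classes in place of TGI's actual config, PositionRotaryEmbedding, and FlashLlamaAttention. It only illustrates why constructing a rotary-embedding cache inside every attention layer costs device memory, and how guarding on model_type skips that per-layer allocation for llama4_text; it is not the repository's implementation.

import dataclasses

import torch


@dataclasses.dataclass
class ToyConfig:
    # Hypothetical minimal config; the field names mirror what the patch reads.
    model_type: str
    head_dim: int = 64
    rope_theta: float = 10000.0


class RotaryCache(torch.nn.Module):
    """Toy stand-in for a static rotary-embedding helper.

    Each instance allocates a buffer on the target device; with one attention
    module per transformer layer, unconditional construction repeats this
    allocation for every layer.
    """

    def __init__(self, dim: int, base: float, device: torch.device):
        super().__init__()
        # Standard RoPE inverse frequencies, placed directly on `device`.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device).float() / dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)


class ToyAttention(torch.nn.Module):
    def __init__(self, config: ToyConfig, device: torch.device):
        super().__init__()
        # The guard from the patch: skip per-layer rotary construction for
        # llama4_text, whose rotary embedding is handled outside this module.
        if config.model_type != "llama4_text":
            self.rotary_emb = RotaryCache(config.head_dim, config.rope_theta, device)
        else:
            self.rotary_emb = None


if __name__ == "__main__":
    device = torch.device("cpu")
    llama3 = ToyAttention(ToyConfig(model_type="llama"), device)
    llama4 = ToyAttention(ToyConfig(model_type="llama4_text"), device)
    print(llama3.rotary_emb is not None)  # True: cache built in this layer
    print(llama4.rotary_emb is None)      # True: skipped, saving device memory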