From fb104d8b427ce6d1c7681e5d05319a5c0ba83c8f Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Thu, 29 May 2025 06:38:45 +0000
Subject: [PATCH] Fix the OOM issue of Llama-4-Scout-17B-16E-Instruct

Signed-off-by: yuanwu
---
 .../models/custom_modeling/flash_llama_modeling.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 0edea03a..dfb16621 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -143,12 +143,14 @@ class FlashLlamaAttention(torch.nn.Module):
         config.num_key_value_heads = getattr(
             config, "num_key_value_heads", config.num_attention_heads
         )
-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+
+        if config.model_type != "llama4_text":
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_theta,
+                device=weights.device,
+            )
 
         # `config.attention_multiplier` is used in Granite
         self.softmax_scale = getattr(
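
Note (not part of the patch): a minimal, self-contained sketch of the guard this
patch introduces. PositionRotaryEmbedding below is a stand-in with an invented
body, and DummyConfig is hypothetical; only the model_type != "llama4_text"
check mirrors the diff. The apparent intent is that llama4_text models obtain
their position embeddings through a separate code path, so eagerly building a
rotary-embedding object for every attention layer only consumes device memory
for them.

    import torch

    class PositionRotaryEmbedding:
        # Stand-in for the TGI helper of the same name: constructing it
        # allocates an inverse-frequency buffer on the target device.
        def __init__(self, inv_freq):
            self.inv_freq = inv_freq

        @classmethod
        def static(cls, config, dim, base, device):
            # Standard RoPE inverse frequencies over the even dimensions.
            inv_freq = 1.0 / (
                base
                ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
            )
            return cls(inv_freq)

    class DummyConfig:
        # Hypothetical config object for the demo; "llama4_text" is the
        # model_type the patch singles out.
        model_type = "llama4_text"
        rope_theta = 500000.0

    config = DummyConfig()
    head_size = 128

    # The guarded construction from the patch: skip the allocation entirely
    # for llama4_text, leaving rotary handling to that model's own code path.
    rotary_emb = None
    if config.model_type != "llama4_text":
        rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=head_size,
            base=config.rope_theta,
            device="cpu",
        )

    print("rotary_emb allocated:", rotary_emb is not None)  # False for llama4_text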