[Gaudi] Fix the OOM issue of Llama-4-Scout-17B-16E-Instruct (#3245)

Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
Yuan Wu 2025-05-29 15:58:24 +08:00 committed by GitHub
parent f14044009a
commit 70217ac345
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -143,12 +143,14 @@ class FlashLlamaAttention(torch.nn.Module):
config.num_key_value_heads = getattr(
config, "num_key_value_heads", config.num_attention_heads
)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
if config.model_type != "llama4_text":
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(