diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 0edea03a..dfb16621 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -143,12 +143,14 @@ class FlashLlamaAttention(torch.nn.Module):
         config.num_key_value_heads = getattr(
             config, "num_key_value_heads", config.num_attention_heads
         )
-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+
+        if config.model_type != "llama4_text":
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_theta,
+                device=weights.device,
+            )
         # `config.attention_multiplier` is used in Granite
         self.softmax_scale = getattr(
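
Below is a minimal, self-contained sketch of the guard this hunk introduces, assuming only a config object exposing model_type and rope_theta. make_rotary_emb and init_rotary are hypothetical stand-ins used for illustration; the real code calls PositionRotaryEmbedding.static with the model config and the weights' device.

from types import SimpleNamespace


def make_rotary_emb(dim, base):
    # Hypothetical stand-in for PositionRotaryEmbedding.static(...);
    # returns a plain dict instead of a real rotary-embedding module.
    return {"dim": dim, "base": base}


def init_rotary(config, head_size):
    # Mirror the guard from the hunk: only non-llama4_text configs get a
    # rotary embedding built inside the attention module.
    if config.model_type != "llama4_text":
        return make_rotary_emb(dim=head_size, base=config.rope_theta)
    return None


llama_cfg = SimpleNamespace(model_type="llama", rope_theta=500000.0)
llama4_cfg = SimpleNamespace(model_type="llama4_text", rope_theta=500000.0)

print(init_rotary(llama_cfg, head_size=128))   # embedding built as before
print(init_rotary(llama4_cfg, head_size=128))  # None: skipped for llama4_text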