diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 579c7dc2..f268e499 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1401,11 +1401,10 @@ class FlashCausalLM(Model):
         if max_bs is None:
             input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
             position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            config = getattr(self.model, "config", None)
+            rope_scaling = getattr(config, "rope_scaling", None)
-            if (  # mrope have position_ids per section, if so repeat n times
+            if (  # mrope has one position_ids column per section; if so, repeat n times
-                hasattr(self.model, "config")
-                and hasattr(self.model.config, "rope_scaling")
-                and "rope_type" in self.model.config.rope_scaling
-                and self.model.config.rope_scaling["rope_type"] == "mrope"
+                isinstance(rope_scaling, dict) and rope_scaling.get("rope_type") == "mrope"
             ):
-                n_sections = len(self.model.config.rope_scaling["mrope_section"])
+                n_sections = len(rope_scaling["mrope_section"])
                 position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
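
For reviewers: a minimal standalone sketch of the guard this hunk refactors. DummyConfig, DummyModel, and the [16, 24, 24] section sizes are illustrative stand-ins for a Qwen2-VL-style transformers config, not part of the patch; only the guard and the repeat mirror the patched code.

import torch

class DummyConfig:
    # hypothetical stand-in for model.config; real mrope configs
    # (e.g. Qwen2-VL) carry rope_scaling with a "mrope_section" list
    rope_scaling = {"rope_type": "mrope", "mrope_section": [16, 24, 24]}

class DummyModel:
    config = DummyConfig()

model = DummyModel()
bs = 4
position_ids = torch.zeros(bs, dtype=torch.int32)

config = getattr(model, "config", None)
rope_scaling = getattr(config, "rope_scaling", None)
# .get() keeps the original `"rope_type" in ...` semantics: a rope_scaling
# dict without the key must fail the check, not raise KeyError
if isinstance(rope_scaling, dict) and rope_scaling.get("rope_type") == "mrope":
    n_sections = len(rope_scaling["mrope_section"])
    # one column of position ids per mrope section: (bs,) -> (bs, n_sections)
    position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)

print(position_ids.shape)  # torch.Size([4, 3])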