diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 47d372ad..579c7dc2 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1404,6 +1404,7 @@ class FlashCausalLM(Model): if ( # mrope have position_ids per section, if so repeat n times hasattr(self.model, "config") and hasattr(self.model.config, "rope_scaling") + and "rope_type" in self.model.config.rope_scaling and self.model.config.rope_scaling["rope_type"] == "mrope" ): n_sections = len(self.model.config.rope_scaling["mrope_section"])