From cb7ec9cb60b1daa1c7e6478c7b79f3d048960c17 Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 29 Jan 2025 13:03:36 +0000
Subject: [PATCH] fix: improve mrope check in cuda graph warmup

---
 server/text_generation_server/models/flash_causal_lm.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 579c7dc2..f268e499 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1401,11 +1401,10 @@ class FlashCausalLM(Model):
         if max_bs is None:
             input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
             position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            config = getattr(self.model, "config", None)
+            rope_scaling = getattr(config, "rope_scaling", None) if config else None
             if (  # mrope have position_ids per section, if so repeat n times
-                hasattr(self.model, "config")
-                and hasattr(self.model.config, "rope_scaling")
-                and "rope_type" in self.model.config.rope_scaling
-                and self.model.config.rope_scaling["rope_type"] == "mrope"
+                isinstance(rope_scaling, dict) and rope_scaling.get("rope_type") == "mrope"
             ):
                 n_sections = len(self.model.config.rope_scaling["mrope_section"])
                 position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
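
Note (editorial sketch, not part of the commit): the new guard can be exercised in
isolation. Below is a minimal standalone sketch of the same logic, assuming a
Qwen2-VL-style config where rope_scaling is a dict such as
{"rope_type": "mrope", "mrope_section": [16, 24, 24]}; DummyConfig, DummyModel,
and warmup_position_ids are hypothetical stand-ins for illustration, not names
from the repository.

    import torch

    class DummyConfig:
        # Illustrative values modeled on Qwen2-VL-style configs (assumption).
        rope_scaling = {"rope_type": "mrope", "mrope_section": [16, 24, 24]}

    class DummyModel:
        config = DummyConfig()

    def warmup_position_ids(model, bs, device="cpu"):
        # Mirrors the patched warmup path: start with flat position ids.
        position_ids = torch.zeros(bs, dtype=torch.int32, device=device)
        config = getattr(model, "config", None)
        rope_scaling = getattr(config, "rope_scaling", None) if config else None
        # mrope keeps one position id per rope section, so repeat n_sections times.
        if isinstance(rope_scaling, dict) and rope_scaling.get("rope_type") == "mrope":
            n_sections = len(rope_scaling["mrope_section"])
            position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
        return position_ids

    print(warmup_position_ids(DummyModel(), bs=4).shape)  # torch.Size([4, 3])
    print(warmup_position_ids(object(), bs=4).shape)      # torch.Size([4]): no config, no expansion

Collapsing the four chained hasattr/in checks into two getattr calls plus an
isinstance test keeps the warmup path safe for models that have no config, a
config without rope_scaling, or a non-dict rope_scaling value.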