fix: improve mrope check in cuda graph warmup

This commit is contained in:
drbh 2025-01-29 13:03:36 +00:00
parent 585e270ac3
commit cb7ec9cb60

View File

@ -1401,11 +1401,10 @@ class FlashCausalLM(Model):
if max_bs is None: if max_bs is None:
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
config = getattr(self.model, "config", None)
rope_scaling = getattr(config, "rope_scaling", None) if config else None
if ( # mrope have position_ids per section, if so repeat n times if ( # mrope have position_ids per section, if so repeat n times
hasattr(self.model, "config") isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope"
and hasattr(self.model.config, "rope_scaling")
and "rope_type" in self.model.config.rope_scaling
and self.model.config.rope_scaling["rope_type"] == "mrope"
): ):
n_sections = len(self.model.config.rope_scaling["mrope_section"]) n_sections = len(self.model.config.rope_scaling["mrope_section"])
position_ids = position_ids.unsqueeze(1).repeat(1, n_sections) position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)