mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-21 16:40:20 +00:00
fix: improve mrope check in cuda graph warmup
This commit is contained in:
parent
585e270ac3
commit
cb7ec9cb60
@ -1401,11 +1401,10 @@ class FlashCausalLM(Model):
|
||||
if max_bs is None:
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
config = getattr(self.model, "config", None)
|
||||
rope_scaling = getattr(config, "rope_scaling", None) if config else None
|
||||
if ( # mrope have position_ids per section, if so repeat n times
|
||||
hasattr(self.model, "config")
|
||||
and hasattr(self.model.config, "rope_scaling")
|
||||
and "rope_type" in self.model.config.rope_scaling
|
||||
and self.model.config.rope_scaling["rope_type"] == "mrope"
|
||||
isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope"
|
||||
):
|
||||
n_sections = len(self.model.config.rope_scaling["mrope_section"])
|
||||
position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
|
||||
|
Loading…
Reference in New Issue
Block a user