From 79a2c956dec6a8126630e817aa529389f2c75471 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 28 Jan 2025 21:08:58 +0000
Subject: [PATCH] fix: improve position id init during cuda warmup for mrope
 and simplify rotary forward

---
 .../text_generation_server/layers/rotary.py       | 19 +++++++++----------
 .../models/flash_causal_lm.py                     | 10 +++++-----
 2 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index c0baaf59..8132478c 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -579,15 +579,14 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ):
-        # prepare input tensors
-        q, k = [x.transpose(0, 1) for x in (query, key)]
-        rotary_dim = cos.shape[-1]
-        q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim]
-        q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
-        k2 = torch.cat((-k[..., rotary_dim // 2 :], k[..., : rotary_dim // 2]), dim=-1)
+        # rotate half of the rotary dimension
+        rot = cos.shape[-1] // 2
+        q2 = torch.cat([-query[..., rot:], query[..., :rot]], dim=-1)
+        k2 = torch.cat([-key[..., rot:], key[..., :rot]], dim=-1)
 
-        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True)
-        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, True)
+        # apply the rotation
+        rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True)
+        rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True)
 
     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # always cache the cos/sin for the full sequence length to avoid
@@ -628,6 +627,6 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
 
         # double the size and add a batch dimension
-        cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(0)
-        sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(0)
+        cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(1)
+        sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(1)
         return cos, sin
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index c5d80bc5..600ed716 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1400,11 +1400,11 @@ class FlashCausalLM(Model):
         cache_lengths = [0] * bs
         if max_bs is None:
             input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-            if hasattr(self.model, "get_position_ids"):
-                # use model specific position ids for initialization
-                position_ids = self.model.get_position_ids(input_ids)
-            else:
-                position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            # mrope models have one position id per section; if so, repeat n times
+            if self.model.config.rope_scaling["rope_type"] == "mrope":
+                n_sections = len(self.model.config.rope_scaling["mrope_section"])
+                position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
             slots = torch.arange(bs, dtype=torch.int64, device=self.device)
             input_lengths_tensor = (
                 torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
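
For reference, the simplified forward applies the standard rotate-half RoPE construction directly to the full query/key tensors, with cos/sin already duplicated along the last dimension by _update_cos_sin_cache (second hunk). Below is a rough pure-PyTorch sketch of the intended result, assuming the fused rotary_emb.apply_rotary(x1, x2, cos, sin, out1, out2, conj=True) call writes x1 * cos + x2 * sin into its first output (my reading of the flash-attn kernel, not verified here); the kernel updates query/key in place, while this sketch returns new tensors.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # swap the two halves of the rotary dimension, negating the second half
    rot = x.shape[-1] // 2
    return torch.cat([-x[..., rot:], x[..., :rot]], dim=-1)

def rope_reference(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # non-fused equivalent of the two apply_rotary calls in the patched forward
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin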
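
The warmup change in flash_causal_lm.py gives mrope models one position id per rope section instead of a flat vector, so the dummy tensors recorded during CUDA graph capture presumably match the per-section position ids the model produces at runtime. A minimal shape sketch, assuming a Qwen2-VL style config where rope_scaling = {"rope_type": "mrope", "mrope_section": [16, 24, 24]} (three sections, conventionally temporal/height/width); the exact section values are illustrative only.

import torch

bs = 4  # dummy warmup batch size
rope_scaling = {"rope_type": "mrope", "mrope_section": [16, 24, 24]}

position_ids = torch.zeros(bs, dtype=torch.int32)  # shape (bs,)
if rope_scaling["rope_type"] == "mrope":
    n_sections = len(rope_scaling["mrope_section"])
    # one position id per section -> shape (bs, n_sections)
    position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)

print(position_ids.shape)  # torch.Size([4, 3])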