fix: improve position id init during cuda warmup for mrope and simplify rotary forward

drbh 2025-01-28 21:08:58 +00:00
parent c75c01e9b9
commit 79a2c956de
2 changed files with 14 additions and 15 deletions


@@ -579,15 +579,14 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ):
-        # prepare input tensors
-        q, k = [x.transpose(0, 1) for x in (query, key)]
-        rotary_dim = cos.shape[-1]
-        q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim]
-        q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
-        k2 = torch.cat((-k[..., rotary_dim // 2 :], k[..., : rotary_dim // 2]), dim=-1)
+        # rotate half the head dimension
+        rot = cos.shape[-1] // 2
+        q2 = torch.cat([-query[..., rot:], query[..., :rot]], dim=-1)
+        k2 = torch.cat([-key[..., rot:], key[..., :rot]], dim=-1)
 
-        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True)
-        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, True)
+        # apply the rotation
+        rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True)
+        rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True)
 
     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # always cache the cos/sin for the full sequence length to avoid
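Note: the simplified forward is just the standard rotate-half rotary formulation. A minimal pure-PyTorch sketch of what it computes, assuming flash-attn's in-place kernel rotary_emb.apply_rotary(x1, x2, cos, sin, out1, out2, True) writes x1 * cos + x2 * sin into out1:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # swap the two halves of the last (head) dimension, negating the back half
    rot = x.shape[-1] // 2
    return torch.cat([-x[..., rot:], x[..., :rot]], dim=-1)

def apply_rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # rotate-half rotary embedding: x * cos + rotate_half(x) * sin
    return x * cos + rotate_half(x) * sin

Because cos/sin now carry a broadcast dim for the heads axis (see the next hunk), the old transpose(0, 1) of query and key is no longer needed.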
@@ -628,6 +627,6 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
 
         # double the size and add a batch dimension
-        cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(0)
-        sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(0)
+        cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(1)
+        sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(1)
         return cos, sin
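The unsqueeze(1) fix pairs with dropping the transpose above: cos/sin come out as (seq_len, 1, head_dim) and broadcast over the heads axis of query/key shaped (seq_len, num_heads, head_dim), instead of the old (1, seq_len, head_dim) layout that required transposed inputs. A self-contained sketch of the section slicing, with a hypothetical Qwen2-VL-style mrope_section of [16, 24, 24] (the shapes, sequence length, and 10000.0 base are assumptions for illustration, not taken from this diff):

import torch

mrope_section = [16, 24, 24]   # hypothetical per-section dims (temporal, height, width)
dim = sum(mrope_section)       # half the head dim; each frequency is used twice
seq_len = 8

# mrope keeps one row of position ids per section type: (3, seq_len)
position_ids = torch.arange(seq_len).unsqueeze(0).repeat(3, 1)

inv_freq = 1.0 / (10000.0 ** (torch.arange(dim, dtype=torch.float32) / dim))
freqs = position_ids[..., None].float() * inv_freq   # (3, seq_len, dim)
cos, sin = freqs.cos(), freqs.sin()

# split along the feature dim; section i reads from position-id row i % 3
split_cos = cos.split(mrope_section, dim=-1)
split_sin = sin.split(mrope_section, dim=-1)
cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)

# (seq_len, 1, 2 * dim): broadcasts over the heads axis of (seq_len, num_heads, head_dim)
cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(1)
sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(1)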


@@ -1400,11 +1400,11 @@ class FlashCausalLM(Model):
         cache_lengths = [0] * bs
         if max_bs is None:
             input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-            if hasattr(self.model, "get_position_ids"):
-                # use model specific position ids for initialization
-                position_ids = self.model.get_position_ids(input_ids)
-            else:
-                position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            # mrope keeps position ids per section; if so, repeat n times
+            if self.model.config.rope_scaling["rope_type"] == "mrope":
+                n_sections = len(self.model.config.rope_scaling["mrope_section"])
+                position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
             slots = torch.arange(bs, dtype=torch.int64, device=self.device)
             input_lengths_tensor = (
                 torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
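Concretely, for an mrope model the warmup position ids go from (bs,) to (bs, n_sections), matching what the multimodal rotary layer expects. A minimal sketch with a hypothetical Qwen2-VL-style config dict:

import torch

rope_scaling = {"rope_type": "mrope", "mrope_section": [16, 24, 24]}  # hypothetical config
bs = 4

position_ids = torch.zeros(bs, dtype=torch.int32)
if rope_scaling["rope_type"] == "mrope":
    # one position id per mrope section (temporal, height, width)
    n_sections = len(rope_scaling["mrope_section"])
    position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)

print(position_ids.shape)  # torch.Size([4, 3])

The previous code called a model-specific get_position_ids hook instead; keying off the rope_scaling config keeps the warmup path model-agnostic.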