diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index f3ec1f62..061bf024 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -557,6 +557,8 @@ def apply_llama3_scaling(
 class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
     def __init__(self, inv_freq, scaling_factor, sections):
         super().__init__(inv_freq, scaling_factor)
+        # expand the inv_freq for the 3 sections
+        self.inv_freq_exp = inv_freq[None, None, :, None].expand(3, -1, -1, 1)
         self.sections = sections * 2
         self._cos_cached = None
         self._sin_cached = None
@@ -568,36 +570,41 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ):
-        mrope_section = self.sections
-        unsqueeze_dim = 1
-
-        split_cos = cos.split(mrope_section, dim=-1)
-        split_sin = sin.split(mrope_section, dim=-1)
-
-        cos = []
-        for i, m in enumerate(split_cos):
-            cos.append(m[i % 3])
-
-        cos = torch.cat(cos, dim=-1).unsqueeze(unsqueeze_dim)
-
-        sin = []
-        for i, m in enumerate(split_sin):
-            sin.append(m[i % 3])
-
-        sin = torch.cat(sin, dim=-1).unsqueeze(unsqueeze_dim)
-
-        q = query.transpose(0, 1).unsqueeze(0)
-        k = key.transpose(0, 1).unsqueeze(0)
-
+        # process multi-modal rotary embeddings
+        split_cos, split_sin = [
+            torch.split(t, self.sections, dim=-1) for t in (cos, sin)
+        ]
+        cos = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1).unsqueeze(
+            1
+        )
+        sin = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1).unsqueeze(
+            1
+        )
+        # prepare input tensors
+        q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)]
         rotary_dim = cos.shape[-1]
-        q1 = q[..., :rotary_dim]
+        q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim]
         q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
-        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True)
-
-        k1 = k[..., :rotary_dim]
         k2 = torch.cat((-k[..., rotary_dim // 2 :], k[..., : rotary_dim // 2]), dim=-1)
+
+        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True)
         rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, True)
 
+    def _update_cos_sin_cache(self, dtype, device, seqlen):
+        # always cache the cos/sin for the full sequence length to avoid
+        # recomputing if the sequence length is smaller than the cached one
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached_exp.device != device
+            or self._cos_cached_exp.dtype != dtype
+        ):
+            self._seq_len_cached = seqlen
+            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+            freqs = freqs.expand(3, -1, -1)
+            self._cos_cached_exp = freqs.cos().to(dtype)
+            self._sin_cached_exp = freqs.sin().to(dtype)
+
     def get_cos_sin(
         self,
         position_ids: torch.Tensor,
@@ -605,21 +612,16 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         dtype: torch.dtype,
     ):
         self._update_cos_sin_cache(dtype, position_ids.device, max_s)
-
-        inv_freq_expanded = (
-            self.inv_freq[None, None, :, None]
-            .float()
-            .expand(3, position_ids.shape[1], -1, 1)
+        # expand the position_ids to match the shape of the cached cos/sin
+        indices = (
+            position_ids.squeeze(1)
+            .unsqueeze(-1)
+            .expand(-1, -1, self._cos_cached_exp.shape[-1])
         )
+        cos_c = torch.gather(self._cos_cached_exp, 1, indices)
+        cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1)
 
-        position_ids_expanded = position_ids[
-            :, :, None, :
-        ].float()  # shape (3, bs, 1, positions)
+        sin_c = torch.gather(self._sin_cached_exp, 1, indices)
+        sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1)
 
-        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
-            2, 3
-        )
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos()
-        sin = emb.sin()
-        return cos.to(dtype), sin.to(dtype)
+        return cos_c, sin_c
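
For reference, a minimal standalone sketch (not part of the patch; all names such as table, indices, cos_old are illustrative) of the caching idea behind this change: the removed get_cos_sin body recomputed inv_freq_expanded @ position_ids_expanded on every call, while the new code builds a (3, seqlen, dim // 2) frequency table once in _update_cos_sin_cache and uses torch.gather to select the rows for each request's per-section position ids. The snippet below checks that the two paths agree:

# Standalone sketch; verifies the gather-from-cache path matches the
# removed outer-product path. Names here are illustrative only.
import torch

dim, seqlen, positions = 8, 32, 5
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
# one row of position ids per mrope section (temporal, height, width)
position_ids = torch.randint(0, seqlen, (3, 1, positions))

# old path: expand inv_freq and matmul against float position ids per call
inv_freq_expanded = (
    inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
)
position_ids_expanded = position_ids[:, :, None, :].float()  # (3, bs, 1, positions)
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)
emb = torch.cat((freqs, freqs), dim=-1)
cos_old, sin_old = emb.cos(), emb.sin()

# new path: build the (3, seqlen, dim // 2) table once, then gather rows
t = torch.arange(seqlen, dtype=inv_freq.dtype)
table = torch.outer(t, inv_freq).expand(3, -1, -1)
indices = position_ids.squeeze(1).unsqueeze(-1).expand(-1, -1, table.shape[-1])
cos_new = torch.cat([torch.gather(table.cos(), 1, indices)] * 2, dim=-1).unsqueeze(1)
sin_new = torch.cat([torch.gather(table.sin(), 1, indices)] * 2, dim=-1).unsqueeze(1)

torch.testing.assert_close(cos_old, cos_new)
torch.testing.assert_close(sin_old, sin_new)
print("cached gather matches the outer-product path")

The trade-off, as the patch's comment suggests, is a one-time cos/sin table of shape (3, seqlen, dim // 2) per device/dtype in exchange for replacing the per-call float matmul with a cheap gather; the cache is rebuilt only when seqlen exceeds the cached length or the device/dtype changes.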