diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index acba2001..5d92ad5e 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -267,13 +267,22 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): or self._cos_cached.dtype != dtype ): self._seq_len_cached = seqlen - if seqlen > self.original_max_position_embeddings: - inv_freq = self.long_inv_freq - else: - inv_freq = self.short_inv_freq - t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq.to(device=t.device)) + t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype) + short_freqs = torch.outer( + t[: self.original_max_position_embeddings], + self.short_inv_freq.to(device=t.device), + ) + long_freqs = torch.outer( + t[self.original_max_position_embeddings :], + self.long_inv_freq.to(device=t.device), + ) + + freqs = torch.cat([short_freqs, long_freqs]) + + from loguru import logger + + logger.info(freqs.shape) self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype) self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)