From 9775facbf798c3e85a9b73aa83bd5d066cf82533 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 12 Jun 2024 17:47:46 +0200 Subject: [PATCH] change arange --- .../text_generation_server/layers/rotary.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index acba2001..5d92ad5e 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -267,13 +267,22 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): or self._cos_cached.dtype != dtype ): self._seq_len_cached = seqlen - if seqlen > self.original_max_position_embeddings: - inv_freq = self.long_inv_freq - else: - inv_freq = self.short_inv_freq - t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq.to(device=t.device)) + t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype) + short_freqs = torch.outer( + t[: self.original_max_position_embeddings], + self.short_inv_freq.to(device=t.device), + ) + long_freqs = torch.outer( + t[self.original_max_position_embeddings :], + self.long_inv_freq.to(device=t.device), + ) + + freqs = torch.cat([short_freqs, long_freqs]) + + from loguru import logger + + logger.info(freqs.shape) self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype) self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)