change arange

This commit is contained in:
OlivierDehaene 2024-06-12 17:47:46 +02:00
parent 9cc16725bf
commit 9775facbf7

View File

@ -267,13 +267,22 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
if seqlen > self.original_max_position_embeddings:
inv_freq = self.long_inv_freq
else:
inv_freq = self.short_inv_freq
t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
freqs = torch.outer(t, inv_freq.to(device=t.device))
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
short_freqs = torch.outer(
t[: self.original_max_position_embeddings],
self.short_inv_freq.to(device=t.device),
)
long_freqs = torch.outer(
t[self.original_max_position_embeddings :],
self.long_inv_freq.to(device=t.device),
)
freqs = torch.cat([short_freqs, long_freqs])
from loguru import logger
logger.info(freqs.shape)
self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)