mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
change arange
This commit is contained in:
parent
9cc16725bf
commit
9775facbf7
@@ -267,13 +267,22 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
if seqlen > self.original_max_position_embeddings:
|
||||
inv_freq = self.long_inv_freq
|
||||
else:
|
||||
inv_freq = self.short_inv_freq
|
||||
t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
|
||||
|
||||
freqs = torch.outer(t, inv_freq.to(device=t.device))
|
||||
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
|
||||
short_freqs = torch.outer(
|
||||
t[: self.original_max_position_embeddings],
|
||||
self.short_inv_freq.to(device=t.device),
|
||||
)
|
||||
long_freqs = torch.outer(
|
||||
t[self.original_max_position_embeddings :],
|
||||
self.long_inv_freq.to(device=t.device),
|
||||
)
|
||||
|
||||
freqs = torch.cat([short_freqs, long_freqs])
|
||||
|
||||
from loguru import logger
|
||||
|
||||
logger.info(freqs.shape)
|
||||
self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user