mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
change arange
This commit is contained in:
parent
9cc16725bf
commit
9775facbf7
@ -267,13 +267,22 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
|||||||
or self._cos_cached.dtype != dtype
|
or self._cos_cached.dtype != dtype
|
||||||
):
|
):
|
||||||
self._seq_len_cached = seqlen
|
self._seq_len_cached = seqlen
|
||||||
if seqlen > self.original_max_position_embeddings:
|
|
||||||
inv_freq = self.long_inv_freq
|
|
||||||
else:
|
|
||||||
inv_freq = self.short_inv_freq
|
|
||||||
t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
|
|
||||||
|
|
||||||
freqs = torch.outer(t, inv_freq.to(device=t.device))
|
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
|
||||||
|
short_freqs = torch.outer(
|
||||||
|
t[: self.original_max_position_embeddings],
|
||||||
|
self.short_inv_freq.to(device=t.device),
|
||||||
|
)
|
||||||
|
long_freqs = torch.outer(
|
||||||
|
t[self.original_max_position_embeddings :],
|
||||||
|
self.long_inv_freq.to(device=t.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
freqs = torch.cat([short_freqs, long_freqs])
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
logger.info(freqs.shape)
|
||||||
self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
|
self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
|
||||||
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
|
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user