Some type annotations

This commit is contained in:
Daniël de Kok 2024-09-25 09:31:19 +00:00
parent 0cafbf3b54
commit 245d6d8f7c

View File

@ -301,6 +301,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
# or if we're on a new device (possibly due to tracing for instance) # or if we're on a new device (possibly due to tracing for instance)
if ( if (
seqlen > self._seq_len_cached seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device or self._cos_cached.device != device
or self._cos_cached.dtype != dtype or self._cos_cached.dtype != dtype
): ):
@ -325,12 +326,12 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
def __init__( def __init__(
self, self,
short_inv_freq, short_inv_freq: torch.Tensor,
long_inv_freq, long_inv_freq: torch.Tensor,
max_position_embeddings, max_position_embeddings: int,
short_mscale, short_mscale: float,
long_mscale, long_mscale: float,
original_max_position_embeddings, original_max_position_embeddings: int,
): ):
super(PositionRotaryEmbedding, self).__init__() super(PositionRotaryEmbedding, self).__init__()
self.short_inv_freq = short_inv_freq self.short_inv_freq = short_inv_freq
@ -538,7 +539,6 @@ def apply_llama3_scaling(
elif wavelen > low_freq_wavelen: elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scaling_factor) new_freqs.append(freq / scaling_factor)
else: else:
assert low_freq_wavelen != high_freq_wavelen assert low_freq_wavelen != high_freq_wavelen
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / ( smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor high_freq_factor - low_freq_factor