diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 713cdf06..a2076bb2 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -301,6 +301,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
         # or if we're on a new device (possibly due to tracing for instance)
         if (
             seqlen > self._seq_len_cached
+            or self._cos_cached is None
             or self._cos_cached.device != device
             or self._cos_cached.dtype != dtype
         ):
@@ -325,12 +326,12 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
 class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
     def __init__(
         self,
-        short_inv_freq,
-        long_inv_freq,
-        max_position_embeddings,
-        short_mscale,
-        long_mscale,
-        original_max_position_embeddings,
+        short_inv_freq: torch.Tensor,
+        long_inv_freq: torch.Tensor,
+        max_position_embeddings: int,
+        short_mscale: float,
+        long_mscale: float,
+        original_max_position_embeddings: int,
     ):
         super(PositionRotaryEmbedding, self).__init__()
         self.short_inv_freq = short_inv_freq
@@ -538,7 +539,6 @@ def apply_llama3_scaling(
         elif wavelen > low_freq_wavelen:
             new_freqs.append(freq / scaling_factor)
         else:
-            assert low_freq_wavelen != high_freq_wavelen
             smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                 high_freq_factor - low_freq_factor