mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Some type annotations
This commit is contained in:
parent
0cafbf3b54
commit
245d6d8f7c
@ -301,6 +301,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached is None
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
@ -325,12 +326,12 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
||||
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(
|
||||
self,
|
||||
short_inv_freq,
|
||||
long_inv_freq,
|
||||
max_position_embeddings,
|
||||
short_mscale,
|
||||
long_mscale,
|
||||
original_max_position_embeddings,
|
||||
short_inv_freq: torch.Tensor,
|
||||
long_inv_freq: torch.Tensor,
|
||||
max_position_embeddings: int,
|
||||
short_mscale: float,
|
||||
long_mscale: float,
|
||||
original_max_position_embeddings: int,
|
||||
):
|
||||
super(PositionRotaryEmbedding, self).__init__()
|
||||
self.short_inv_freq = short_inv_freq
|
||||
@ -538,7 +539,6 @@ def apply_llama3_scaling(
|
||||
elif wavelen > low_freq_wavelen:
|
||||
new_freqs.append(freq / scaling_factor)
|
||||
else:
|
||||
|
||||
assert low_freq_wavelen != high_freq_wavelen
|
||||
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
|
||||
high_freq_factor - low_freq_factor
|
||||
|
Loading…
Reference in New Issue
Block a user