mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Some type annotations
This commit is contained in:
parent
0cafbf3b54
commit
245d6d8f7c
@ -301,6 +301,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
|||||||
# or if we're on a new device (possibly due to tracing for instance)
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
if (
|
if (
|
||||||
seqlen > self._seq_len_cached
|
seqlen > self._seq_len_cached
|
||||||
|
or self._cos_cached is None
|
||||||
or self._cos_cached.device != device
|
or self._cos_cached.device != device
|
||||||
or self._cos_cached.dtype != dtype
|
or self._cos_cached.dtype != dtype
|
||||||
):
|
):
|
||||||
@ -325,12 +326,12 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
|||||||
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
|
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
short_inv_freq,
|
short_inv_freq: torch.Tensor,
|
||||||
long_inv_freq,
|
long_inv_freq: torch.Tensor,
|
||||||
max_position_embeddings,
|
max_position_embeddings: int,
|
||||||
short_mscale,
|
short_mscale: float,
|
||||||
long_mscale,
|
long_mscale: float,
|
||||||
original_max_position_embeddings,
|
original_max_position_embeddings: int,
|
||||||
):
|
):
|
||||||
super(PositionRotaryEmbedding, self).__init__()
|
super(PositionRotaryEmbedding, self).__init__()
|
||||||
self.short_inv_freq = short_inv_freq
|
self.short_inv_freq = short_inv_freq
|
||||||
@ -538,7 +539,6 @@ def apply_llama3_scaling(
|
|||||||
elif wavelen > low_freq_wavelen:
|
elif wavelen > low_freq_wavelen:
|
||||||
new_freqs.append(freq / scaling_factor)
|
new_freqs.append(freq / scaling_factor)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
assert low_freq_wavelen != high_freq_wavelen
|
assert low_freq_wavelen != high_freq_wavelen
|
||||||
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
|
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
|
||||||
high_freq_factor - low_freq_factor
|
high_freq_factor - low_freq_factor
|
||||||
|
Loading…
Reference in New Issue
Block a user