fix: consolidate long rope paths

This commit is contained in:
drbh 2024-09-06 14:17:56 +00:00 committed by Daniël de Kok
parent b1026a84cb
commit dad070b1fc

View File

@ -89,43 +89,6 @@ class PositionRotaryEmbedding(nn.Module):
if rope_type == "linear": if rope_type == "linear":
pass pass
elif rope_type == "longrope":
short_factor = torch.tensor(
rope_scaling["short_factor"], dtype=torch.float32, device=device
)
long_factor = torch.tensor(
rope_scaling["long_factor"], dtype=torch.float32, device=device
)
short_mscale = rope_scaling["short_mscale"]
long_mscale = rope_scaling["long_mscale"]
original_max_position_embeddings = (
config.original_max_position_embeddings
)
return Phi3LongRoPEScaledRotaryEmbedding(
short_inv_freq=1.0
/ (
short_factor
* base
** (
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
/ dim
)
),
long_inv_freq=1.0
/ (
long_factor
* base
** (
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
/ dim
)
),
max_position_embeddings=config.max_position_embeddings,
short_mscale=short_mscale,
long_mscale=long_mscale,
original_max_position_embeddings=original_max_position_embeddings,
)
elif rope_type == "dynamic": elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding( return DynamicPositionRotaryEmbedding(
@ -203,6 +166,20 @@ class PositionRotaryEmbedding(nn.Module):
1 + math.log(scale) / math.log(original_max_position_embeddings) 1 + math.log(scale) / math.log(original_max_position_embeddings)
) )
# if short_mscale and long_mscale are provided we need to scale the freqs
# using the Phi3LongRoPEScaledRotaryEmbedding
if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
short_mscale = rope_scaling["short_mscale"]
long_mscale = rope_scaling["long_mscale"]
return Phi3LongRoPEScaledRotaryEmbedding(
short_inv_freq=short_inv_freq,
long_inv_freq=long_inv_freq,
max_position_embeddings=config.max_position_embeddings,
short_mscale=short_mscale,
long_mscale=long_mscale,
original_max_position_embeddings=original_max_position_embeddings,
)
return SuRotaryEmbedding( return SuRotaryEmbedding(
short_inv_freq=short_inv_freq, short_inv_freq=short_inv_freq,
long_inv_freq=long_inv_freq, long_inv_freq=long_inv_freq,