Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
fix: consolidate long rope paths
This commit is contained in:
parent b1026a84cb
commit dad070b1fc
@@ -89,43 +89,6 @@ class PositionRotaryEmbedding(nn.Module):
             if rope_type == "linear":
                 pass
-            elif rope_type == "longrope":
-                short_factor = torch.tensor(
-                    rope_scaling["short_factor"], dtype=torch.float32, device=device
-                )
-                long_factor = torch.tensor(
-                    rope_scaling["long_factor"], dtype=torch.float32, device=device
-                )
-                short_mscale = rope_scaling["short_mscale"]
-                long_mscale = rope_scaling["long_mscale"]
-                original_max_position_embeddings = (
-                    config.original_max_position_embeddings
-                )
-                return Phi3LongRoPEScaledRotaryEmbedding(
-                    short_inv_freq=1.0
-                    / (
-                        short_factor
-                        * base
-                        ** (
-                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
-                            / dim
-                        )
-                    ),
-                    long_inv_freq=1.0
-                    / (
-                        long_factor
-                        * base
-                        ** (
-                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
-                            / dim
-                        )
-                    ),
-                    max_position_embeddings=config.max_position_embeddings,
-                    short_mscale=short_mscale,
-                    long_mscale=long_mscale,
-                    original_max_position_embeddings=original_max_position_embeddings,
-                )
-
             elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(
@@ -203,6 +166,20 @@ class PositionRotaryEmbedding(nn.Module):
                         1 + math.log(scale) / math.log(original_max_position_embeddings)
                     )
 
+                # if short_mscale and long_mscale are provided we need to scale the freqs
+                # using the Phi3LongRoPEScaledRotaryEmbedding
+                if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
+                    short_mscale = rope_scaling["short_mscale"]
+                    long_mscale = rope_scaling["long_mscale"]
+                    return Phi3LongRoPEScaledRotaryEmbedding(
+                        short_inv_freq=short_inv_freq,
+                        long_inv_freq=long_inv_freq,
+                        max_position_embeddings=config.max_position_embeddings,
+                        short_mscale=short_mscale,
+                        long_mscale=long_mscale,
+                        original_max_position_embeddings=original_max_position_embeddings,
+                    )
+
                 return SuRotaryEmbedding(
                     short_inv_freq=short_inv_freq,
                     long_inv_freq=long_inv_freq,
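Below is a minimal, self-contained Python sketch of the dispatch this diff consolidates into the long-rope branch. It is not the repository code: the helper names are hypothetical and the branches return plain tuples rather than real embedding modules; only the inverse-frequency formula and the short_mscale/long_mscale check are taken from the diff above.

# Illustrative sketch only -- not the code in this repository. It mirrors the
# logic in the diff: compute the scaled inverse frequencies once, then take the
# Phi3 long-rope path only when explicit mscale values are present.
import torch


def _scaled_inv_freq(factor, base, dim, device=None):
    # 1.0 / (factor * base ** (arange(0, dim, 2) / dim)), as in both hunks above.
    factor = torch.tensor(factor, dtype=torch.float32, device=device)
    exponent = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
    return 1.0 / (factor * base**exponent)


def pick_long_rope_path(rope_scaling, base, dim, device=None):
    # Hypothetical helper: reports which embedding class the consolidated path
    # would construct, together with the frequencies it would receive.
    short_inv_freq = _scaled_inv_freq(rope_scaling["short_factor"], base, dim, device)
    long_inv_freq = _scaled_inv_freq(rope_scaling["long_factor"], base, dim, device)

    if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
        # Explicit mscales -> Phi3LongRoPEScaledRotaryEmbedding branch.
        return "Phi3LongRoPEScaledRotaryEmbedding", short_inv_freq, long_inv_freq
    # Otherwise the generic SuRotaryEmbedding branch is taken.
    return "SuRotaryEmbedding", short_inv_freq, long_inv_freq


if __name__ == "__main__":
    path, _, _ = pick_long_rope_path(
        {
            "short_factor": [1.0] * 32,
            "long_factor": [4.0] * 32,
            "short_mscale": 1.1,
            "long_mscale": 1.2,
        },
        base=10000.0,
        dim=64,
    )
    print(path)  # -> Phi3LongRoPEScaledRotaryEmbedding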