mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: consolidate long rope paths
This commit is contained in:
parent
b1026a84cb
commit
dad070b1fc
@ -89,43 +89,6 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
|
|
||||||
if rope_type == "linear":
|
if rope_type == "linear":
|
||||||
pass
|
pass
|
||||||
elif rope_type == "longrope":
|
|
||||||
short_factor = torch.tensor(
|
|
||||||
rope_scaling["short_factor"], dtype=torch.float32, device=device
|
|
||||||
)
|
|
||||||
long_factor = torch.tensor(
|
|
||||||
rope_scaling["long_factor"], dtype=torch.float32, device=device
|
|
||||||
)
|
|
||||||
short_mscale = rope_scaling["short_mscale"]
|
|
||||||
long_mscale = rope_scaling["long_mscale"]
|
|
||||||
original_max_position_embeddings = (
|
|
||||||
config.original_max_position_embeddings
|
|
||||||
)
|
|
||||||
return Phi3LongRoPEScaledRotaryEmbedding(
|
|
||||||
short_inv_freq=1.0
|
|
||||||
/ (
|
|
||||||
short_factor
|
|
||||||
* base
|
|
||||||
** (
|
|
||||||
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
|
|
||||||
/ dim
|
|
||||||
)
|
|
||||||
),
|
|
||||||
long_inv_freq=1.0
|
|
||||||
/ (
|
|
||||||
long_factor
|
|
||||||
* base
|
|
||||||
** (
|
|
||||||
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
|
|
||||||
/ dim
|
|
||||||
)
|
|
||||||
),
|
|
||||||
max_position_embeddings=config.max_position_embeddings,
|
|
||||||
short_mscale=short_mscale,
|
|
||||||
long_mscale=long_mscale,
|
|
||||||
original_max_position_embeddings=original_max_position_embeddings,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif rope_type == "dynamic":
|
elif rope_type == "dynamic":
|
||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
return DynamicPositionRotaryEmbedding(
|
return DynamicPositionRotaryEmbedding(
|
||||||
@ -203,6 +166,20 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
1 + math.log(scale) / math.log(original_max_position_embeddings)
|
1 + math.log(scale) / math.log(original_max_position_embeddings)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# if short_mscale and long_mscale are provided we need to scale the freqs
|
||||||
|
# using the Phi3LongRoPEScaledRotaryEmbedding
|
||||||
|
if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
|
||||||
|
short_mscale = rope_scaling["short_mscale"]
|
||||||
|
long_mscale = rope_scaling["long_mscale"]
|
||||||
|
return Phi3LongRoPEScaledRotaryEmbedding(
|
||||||
|
short_inv_freq=short_inv_freq,
|
||||||
|
long_inv_freq=long_inv_freq,
|
||||||
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
|
short_mscale=short_mscale,
|
||||||
|
long_mscale=long_mscale,
|
||||||
|
original_max_position_embeddings=original_max_position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
return SuRotaryEmbedding(
|
return SuRotaryEmbedding(
|
||||||
short_inv_freq=short_inv_freq,
|
short_inv_freq=short_inv_freq,
|
||||||
long_inv_freq=long_inv_freq,
|
long_inv_freq=long_inv_freq,
|
||||||
|
Loading…
Reference in New Issue
Block a user