Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-12 04:44:52 +00:00
Add support for Llama 3 rotary embeddings
This commit is contained in:
parent 9935720c87
commit 2c3b078911
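For context, the branch added in this commit keys off the model's `rope_scaling` configuration. A minimal sketch of the section it expects is below; the key names are the ones the diff reads, while the numeric values are the ones commonly shipped with Llama 3.1 checkpoints and are an assumption here, not something this commit defines.

# Illustrative only: shape of the rope_scaling config the new "llama3" branch parses.
rope_scaling = {
    "rope_type": "llama3",  # newer checkpoints; older configs use the "type" key instead
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}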
@@ -1,4 +1,5 @@
 import os
+import math
 import torch
 from torch import nn
 from loguru import logger
@@ -85,9 +86,13 @@ class PositionRotaryEmbedding(nn.Module):
         scaling_factor = None
         rope_scaling = _get_rope_config(config)
         if rope_scaling is not None:
-            if rope_scaling["type"] == "linear":
+            # `rope_type` is now standard in transformers, but some existing models
+            # have `type` instead.
+            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+
+            if rope_type == "linear":
                 pass
-            elif rope_scaling["type"] == "dynamic":
+            elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(
                     dim=dim,
@@ -96,7 +101,20 @@ class PositionRotaryEmbedding(nn.Module):
                     device=inv_freq.device,
                     scaling_factor=scaling_factor,
                 )
-            elif rope_scaling["type"] == "yarn":
+            elif rope_type == "llama3":
+                inv_freq = apply_llama3_scaling(
+                    inv_freq,
+                    scaling_factor=rope_scaling["factor"],
+                    low_freq_factor=rope_scaling["low_freq_factor"],
+                    high_freq_factor=rope_scaling["high_freq_factor"],
+                    original_max_position_embeddings=rope_scaling[
+                        "original_max_position_embeddings"
+                    ],
+                )
+
+                return cls(inv_freq, scaling_factor)
+
+            elif rope_type == "yarn":
                 scaling_factor = rope_scaling["factor"]
                 mscale = rope_scaling.get("mscale", 1.0)
                 mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
@@ -115,7 +133,7 @@ class PositionRotaryEmbedding(nn.Module):
                     mscale=mscale,
                     mscale_all_dim=mscale_all_dim,
                 )
-            elif rope_scaling["type"] in ["su", "longrope"]:
+            elif rope_type in ["su", "longrope"]:
                 short_factor = torch.tensor(
                     rope_scaling["short_factor"], dtype=torch.float32, device=device
                 )
@@ -327,10 +345,6 @@ class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
             self._sin_cached = torch.sin(freqs).to(dtype)
 
 
-# Inverse dim formula to find dim based on number of rotations
-import math
-
-
 def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
     return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
         2 * math.log(base)
@@ -434,3 +448,33 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
             self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
+
+
+def apply_llama3_scaling(
+    freqs: torch.Tensor,
+    *,
+    scaling_factor: int,
+    low_freq_factor: int,
+    high_freq_factor: int,
+    original_max_position_embeddings: int,
+):
+    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
+    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
+    new_freqs = []
+
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scaling_factor)
+        else:
+
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
+
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
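The `apply_llama3_scaling` helper added above rewrites each inverse frequency according to its wavelength: short wavelengths (high frequencies) are kept as-is, long wavelengths are divided by the scaling factor, and the band in between is linearly interpolated. The sketch below is not part of the commit; it assumes `apply_llama3_scaling` from the patch above is in scope and reuses the illustrative Llama 3.1-style parameters from the config sketch near the top of this page.

# Illustrative usage of apply_llama3_scaling on a toy inverse-frequency vector.
import math
import torch

dim, base = 128, 500000.0  # assumed head dim / rope_theta, for illustration only
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

scaled = apply_llama3_scaling(
    inv_freq,
    scaling_factor=8,
    low_freq_factor=1,
    high_freq_factor=4,
    original_max_position_embeddings=8192,
)

# Check the three regimes described above.
wavelen = 2 * math.pi / inv_freq
kept = wavelen < 8192 / 4       # shorter than high_freq_wavelen: unchanged
stretched = wavelen > 8192 / 1  # longer than low_freq_wavelen: divided by the factor
print("unchanged bins:", int(kept.sum()), torch.allclose(scaled[kept], inv_freq[kept]))
print("rescaled bins:", int(stretched.sum()),
      torch.allclose(scaled[stretched], inv_freq[stretched] / 8))
print("interpolated bins:", int((~kept & ~stretched).sum()))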