mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Add support for Llama 3 rotary embeddings

commit 2c3b078911
parent 9935720c87
@@ -1,4 +1,5 @@
 import os
+import math
 import torch
 from torch import nn
 from loguru import logger
@@ -85,9 +86,13 @@ class PositionRotaryEmbedding(nn.Module):
         scaling_factor = None
         rope_scaling = _get_rope_config(config)
         if rope_scaling is not None:
-            if rope_scaling["type"] == "linear":
+            # `rope_type` is now standard in transformers, but some existing models
+            # have `type` instead.
+            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+
+            if rope_type == "linear":
                 pass
-            elif rope_scaling["type"] == "dynamic":
+            elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(
                     dim=dim,
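
Note: a minimal sketch of the `rope_type`/`type` fallback above; the config dicts are hypothetical, not taken from real checkpoints.

# Hypothetical configs: new-style `rope_type` and legacy `type` resolve the same.
new_style = {"rope_type": "dynamic", "factor": 2.0}
old_style = {"type": "dynamic", "factor": 2.0}

for rope_scaling in (new_style, old_style):
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
    assert rope_type == "dynamic"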
@@ -96,7 +101,20 @@ class PositionRotaryEmbedding(nn.Module):
                     device=inv_freq.device,
                     scaling_factor=scaling_factor,
                 )
-            elif rope_scaling["type"] == "yarn":
+            elif rope_type == "llama3":
+                inv_freq = apply_llama3_scaling(
+                    inv_freq,
+                    scaling_factor=rope_scaling["factor"],
+                    low_freq_factor=rope_scaling["low_freq_factor"],
+                    high_freq_factor=rope_scaling["high_freq_factor"],
+                    original_max_position_embeddings=rope_scaling[
+                        "original_max_position_embeddings"
+                    ],
+                )
+
+                return cls(inv_freq, scaling_factor)
+
+            elif rope_type == "yarn":
                 scaling_factor = rope_scaling["factor"]
                 mscale = rope_scaling.get("mscale", 1.0)
                 mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
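
Note: the new `llama3` branch expects a `rope_scaling` entry of roughly the shape below. The values mirror the published Llama 3.1 configs, but treat them as illustrative rather than normative.

# Illustrative `rope_scaling` config that would take the `llama3` branch above.
rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}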
@@ -115,7 +133,7 @@ class PositionRotaryEmbedding(nn.Module):
                     mscale=mscale,
                     mscale_all_dim=mscale_all_dim,
                 )
-            elif rope_scaling["type"] in ["su", "longrope"]:
+            elif rope_type in ["su", "longrope"]:
                 short_factor = torch.tensor(
                     rope_scaling["short_factor"], dtype=torch.float32, device=device
                 )
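
Note: the `su`/`longrope` branch reads per-dimension factor lists. A hypothetical minimal config shape (real checkpoints such as Phi-3 carry one factor per rotary dimension pair; the lengths and values here are arbitrary):

# Hypothetical longrope-style config; list lengths equal half the rotary dim.
rope_scaling = {
    "rope_type": "longrope",
    "short_factor": [1.0] * 48,
    "long_factor": [4.0] * 48,
}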
@@ -327,10 +345,6 @@ class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
             self._sin_cached = torch.sin(freqs).to(dtype)


-# Inverse dim formula to find dim based on number of rotations
-import math
-
-
 def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
     return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
         2 * math.log(base)
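
Note: `find_correction_dim` (context above) inverts the rotation count of a RoPE dimension: index `i` has inverse frequency `base ** (-2 * i / dim)` and completes `max_position_embeddings * base ** (-2 * i / dim) / (2 * pi)` rotations over the original context. A quick numerical check, assuming `find_correction_dim` from this file is in scope (values arbitrary):

import math

def rotations(i, dim, base=10000, max_position_embeddings=2048):
    # Full rotations completed by dimension index `i` over the context window.
    inv_freq = 1.0 / (base ** (2 * i / dim))
    return max_position_embeddings * inv_freq / (2 * math.pi)

# Inverting rotations(i) with find_correction_dim should recover i.
print(find_correction_dim(rotations(20, dim=128), dim=128))  # ~20.0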
@@ -434,3 +448,33 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
             self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
+
+
+def apply_llama3_scaling(
+    freqs: torch.Tensor,
+    *,
+    scaling_factor: int,
+    low_freq_factor: int,
+    high_freq_factor: int,
+    original_max_position_embeddings: int,
+):
+    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
+    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
+    new_freqs = []
+
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scaling_factor)
+        else:
+
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
+
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
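
Note: a self-contained sketch of the new helper in action, assuming `apply_llama3_scaling` from this diff is in scope; the head dim and rope base below are arbitrary illustrative values.

import torch

dim, base = 128, 500000.0  # illustrative head dimension and rope theta
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

scaled = apply_llama3_scaling(
    inv_freq,
    scaling_factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    original_max_position_embeddings=8192,
)

# Short wavelengths pass through unchanged, long wavelengths are divided by
# the scaling factor, and the band in between is linearly interpolated.
print((scaled == inv_freq).sum().item(), (scaled < inv_freq).sum().item())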