Add support for Llama 3 rotary embeddings

Daniël de Kok 2024-07-23 14:34:56 +00:00
parent 9935720c87
commit 2c3b078911
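
For context, the code path added here is selected by the `rope_scaling` section of a model's config. Below is a minimal sketch of the dict that `_get_rope_config(config)` would return for such a model; it is not part of the commit, and the concrete values are those published for Llama 3.1, shown purely for illustration.

# Sketch (not part of this commit): the rope_scaling dict the new "llama3"
# branch expects. Values are those published for Llama 3.1, for illustration.
rope_scaling = {
    "rope_type": "llama3",  # newer configs; older checkpoints may still use "type"
    "factor": 8.0,  # consumed as `scaling_factor`
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}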


@@ -1,4 +1,5 @@
 import os
+import math
 import torch
 from torch import nn
 from loguru import logger
@@ -85,9 +86,13 @@ class PositionRotaryEmbedding(nn.Module):
         scaling_factor = None
         rope_scaling = _get_rope_config(config)
         if rope_scaling is not None:
-            if rope_scaling["type"] == "linear":
+            # `rope_type` is now standard in transformers, but some existing models
+            # have `type` instead.
+            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+
+            if rope_type == "linear":
                 pass
-            elif rope_scaling["type"] == "dynamic":
+            elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(
                     dim=dim,
@@ -96,7 +101,20 @@ class PositionRotaryEmbedding(nn.Module):
                     device=inv_freq.device,
                     scaling_factor=scaling_factor,
                 )
-            elif rope_scaling["type"] == "yarn":
+            elif rope_type == "llama3":
+                inv_freq = apply_llama3_scaling(
+                    inv_freq,
+                    scaling_factor=rope_scaling["factor"],
+                    low_freq_factor=rope_scaling["low_freq_factor"],
+                    high_freq_factor=rope_scaling["high_freq_factor"],
+                    original_max_position_embeddings=rope_scaling[
+                        "original_max_position_embeddings"
+                    ],
+                )
+
+                return cls(inv_freq, scaling_factor)
+
+            elif rope_type == "yarn":
                 scaling_factor = rope_scaling["factor"]
                 mscale = rope_scaling.get("mscale", 1.0)
                 mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
@@ -115,7 +133,7 @@ class PositionRotaryEmbedding(nn.Module):
                     mscale=mscale,
                     mscale_all_dim=mscale_all_dim,
                 )
-            elif rope_scaling["type"] in ["su", "longrope"]:
+            elif rope_type in ["su", "longrope"]:
                 short_factor = torch.tensor(
                     rope_scaling["short_factor"], dtype=torch.float32, device=device
                 )
@@ -327,10 +345,6 @@ class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
             self._sin_cached = torch.sin(freqs).to(dtype)
 
 
-# Inverse dim formula to find dim based on number of rotations
-import math
-
-
 def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
     return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
         2 * math.log(base)
@@ -434,3 +448,33 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
             self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
+
+
+def apply_llama3_scaling(
+    freqs: torch.Tensor,
+    *,
+    scaling_factor: int,
+    low_freq_factor: int,
+    high_freq_factor: int,
+    original_max_position_embeddings: int,
+):
+    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
+    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
+    new_freqs = []
+
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scaling_factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+
+            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+
+            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
+
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
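
To make the effect of `apply_llama3_scaling` concrete, the standalone sketch below recomputes the two wavelength thresholds and counts how many rotary dimensions fall into each band: frequencies whose wavelength is shorter than `high_freq_wavelen` are kept as-is, those longer than `low_freq_wavelen` are divided by the scaling factor, and the band in between is smoothly interpolated. The head size of 128, RoPE base of 500000, and the scaling parameters are assumed Llama 3.1 values used only for this example; the sketch is not part of the commit.

# Standalone sketch (not part of the commit) of the three frequency bands that
# apply_llama3_scaling distinguishes. Head size, base, and scaling parameters
# are assumed Llama 3.1 values, used only for illustration.
import math

import torch

scaling_factor = 8.0
low_freq_factor = 1.0
high_freq_factor = 4.0
original_max_position_embeddings = 8192

low_freq_wavelen = original_max_position_embeddings / low_freq_factor  # 8192.0
high_freq_wavelen = original_max_position_embeddings / high_freq_factor  # 2048.0

# Standard RoPE inverse frequencies for a head size of 128.
dim, base = 128, 500000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
wavelen = 2 * math.pi / inv_freq

kept = (wavelen < high_freq_wavelen).sum().item()  # high frequencies: unchanged
scaled = (wavelen > low_freq_wavelen).sum().item()  # low frequencies: divided by the factor
blended = inv_freq.numel() - kept - scaled  # middle band: smoothly interpolated
print(f"kept={kept}, scaled={scaled}, blended={blended} of {inv_freq.numel()} dims")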