From 2c3b0789117f5e6701b9932307d05a85233c6fd1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Tue, 23 Jul 2024 14:34:56 +0000
Subject: [PATCH] Add support for Llama 3 rotary embeddings

---
 .../text_generation_server/layers/rotary.py  | 60 ++++++++++++++++---
 1 file changed, 52 insertions(+), 8 deletions(-)

diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index db78ee1c..e6e96abb 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -1,4 +1,5 @@
 import os
+import math
 import torch
 from torch import nn
 from loguru import logger
@@ -85,9 +86,13 @@ class PositionRotaryEmbedding(nn.Module):
         scaling_factor = None
         rope_scaling = _get_rope_config(config)
         if rope_scaling is not None:
-            if rope_scaling["type"] == "linear":
+            # `rope_type` is now standard in transformers, but some existing models
+            # have `type` instead.
+            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+
+            if rope_type == "linear":
                 pass
-            elif rope_scaling["type"] == "dynamic":
+            elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(
                     dim=dim,
@@ -96,7 +101,20 @@ class PositionRotaryEmbedding(nn.Module):
                     device=inv_freq.device,
                     scaling_factor=scaling_factor,
                 )
-            elif rope_scaling["type"] == "yarn":
+            elif rope_type == "llama3":
+                inv_freq = apply_llama3_scaling(
+                    inv_freq,
+                    scaling_factor=rope_scaling["factor"],
+                    low_freq_factor=rope_scaling["low_freq_factor"],
+                    high_freq_factor=rope_scaling["high_freq_factor"],
+                    original_max_position_embeddings=rope_scaling[
+                        "original_max_position_embeddings"
+                    ],
+                )
+
+                return cls(inv_freq, scaling_factor)
+
+            elif rope_type == "yarn":
                 scaling_factor = rope_scaling["factor"]
                 mscale = rope_scaling.get("mscale", 1.0)
                 mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
@@ -115,7 +133,7 @@ class PositionRotaryEmbedding(nn.Module):
                     mscale=mscale,
                     mscale_all_dim=mscale_all_dim,
                 )
-            elif rope_scaling["type"] in ["su", "longrope"]:
+            elif rope_type in ["su", "longrope"]:
                 short_factor = torch.tensor(
                     rope_scaling["short_factor"], dtype=torch.float32, device=device
                 )
@@ -327,10 +345,6 @@ class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
             self._sin_cached = torch.sin(freqs).to(dtype)
 
 
-# Inverse dim formula to find dim based on number of rotations
-import math
-
-
 def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
     return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
         2 * math.log(base)
@@ -434,3 +448,33 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
             self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
+
+
+def apply_llama3_scaling(
+    freqs: torch.Tensor,
+    *,
+    scaling_factor: int,
+    low_freq_factor: int,
+    high_freq_factor: int,
+    original_max_position_embeddings: int,
+):
+    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
+    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
+    new_freqs = []
+
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scaling_factor)
+        else:
+
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
+
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
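
Illustration (not part of the patch): the new apply_llama3_scaling helper splits the inverse frequencies into three wavelength bands. Components with a wavelength shorter than original_max_position_embeddings / high_freq_factor are kept as-is, components with a wavelength longer than original_max_position_embeddings / low_freq_factor are divided by the scaling factor, and the band in between is linearly interpolated between the two. The sketch below restates that rule outside of text_generation_server so it runs standalone; the rope_scaling numbers (factor 8.0, low_freq_factor 1.0, high_freq_factor 4.0, original_max_position_embeddings 8192) and the dim/base values are assumed Llama 3.1-style values for the example, not something this diff fixes.

import math

import torch


def llama3_scale(
    freqs: torch.Tensor,
    *,
    scaling_factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    original_max_position_embeddings: int,
) -> torch.Tensor:
    # Same piecewise rule as the patched apply_llama3_scaling, restated here so
    # the example does not need to import text_generation_server.
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    out = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency band: short wavelengths are left untouched.
            out.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency band: long wavelengths are divided by the factor.
            out.append(freq / scaling_factor)
        else:
            # Transition band: interpolate between the scaled and unscaled value.
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            out.append((1 - smooth) * freq / scaling_factor + smooth * freq)
    return torch.tensor(out, dtype=freqs.dtype)


# Inverse frequencies as _create_inv_freq would build them for dim=128, base=500000
# (assumed values chosen to resemble a Llama 3.1 configuration).
dim, base = 128, 500000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

scaled = llama3_scale(
    inv_freq,
    scaling_factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    original_max_position_embeddings=8192,
)

wavelen = 2 * math.pi / inv_freq
short_band = wavelen < 8192 / 4.0  # high-frequency band
long_band = wavelen > 8192 / 1.0   # low-frequency band
print(torch.equal(scaled[short_band], inv_freq[short_band]))      # True: unchanged
print(torch.equal(scaled[long_band], inv_freq[long_band] / 8.0))  # True: divided by the factor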