Adding yarn support.

This commit is contained in:
Nicolas Patry 2023-10-04 15:17:02 +00:00
parent 8ec1b87f16
commit 04323d27dd

View File

@ -601,6 +601,19 @@ try:
device=inv_freq.device, device=inv_freq.device,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
) )
elif rope_scaling["type"] == "yarn":
return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
base=10000.0,
device=inv_freq.device,
scaling_factor=scaling_factor,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1
)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid" f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
@ -629,6 +642,19 @@ try:
device=inv_freq.device, device=inv_freq.device,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
) )
elif rope_scaling["type"] == "yarn":
return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
base=10000.0,
device=inv_freq.device,
scaling_factor=scaling_factor,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1
)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid" f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
@ -708,5 +734,76 @@ try:
self._cos_cached = torch.cos(freqs).to(dtype) self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype)
# Inverse dim formula to find dim based on number of rotations
import math
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
# Find dim range bounds based on rotations
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low = math.floor(find_correction_dim(
low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(
high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim-1) # Clamp values just in case
def linear_ramp_mask(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def get_mscale(scale=1):
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor,*, extrapolation_factor, attn_factor, beta_fast, beta_slow):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
inv_freq_extrapolation = _create_inv_freq(
self.dim, self.base, self.inv_freq.device
)
freqs = 1.0 / inv_freq_extrapolation
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings)
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
self.inv_freq = inv_freq
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
except ImportError: except ImportError:
pass pass