implement dyno rope (dynamic RoPE scaling)

Authored by Yessen Kanapin on 2023-08-25 13:20:31 -07:00
parent 50fd663b9b
commit de7748ff41
4 changed files with 89 additions and 28 deletions


@@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_scaling=None,
+        rope_theta=10000.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size

@@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
         self.pretraining_tp = pretraining_tp
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
+        self.rope_theta = rope_theta

         super().__init__(
             pad_token_id=pad_token_id,
@@ -185,8 +187,8 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            dim=self.head_size, base=config.rope_theta, device=weights.device, config=config
         )

         self.softmax_scale = self.head_size**-0.5
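The Llama path no longer reads a serialized `inv_freq` tensor; the rotary embedding is now derived from the config. A minimal sketch of how the two new config fields are meant to be consumed (the "dynamic" type and the factor value are illustrative choices, not part of the diff):

    config = LlamaConfig(
        rope_theta=10000.0,
        rope_scaling={"type": "dynamic", "factor": 2.0},  # read back via _get_rope_config
    )
    # FlashLlamaAttention then builds its rotary embedding straight from the config:
    # PositionRotaryEmbedding.static(
    #     dim=head_size, base=config.rope_theta, device=device, config=config
    # )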


@@ -102,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = self.num_heads // weights.process_group.size()

         self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            prefix=f"{prefix}.rotary_emb", weights=weights, config=config,
         )

         self.softmax_scale = self.head_size ** (-0.5)


@@ -133,7 +133,7 @@ class FlashRWAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
-            dim=self.head_size, base=10000.0, device=weights.device
+            dim=self.head_size, base=10000.0, device=weights.device, config=config,
         )

         self.softmax_scale = self.head_size ** (-0.5)

@@ -247,7 +247,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.head_size = hidden_size // num_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
-            self.head_size, base=10000.0, device=weights.device
+            self.head_size, base=10000.0, device=weights.device, config=config,
         )

         self.softmax_scale = self.head_size ** (-0.5)


@@ -26,8 +26,6 @@ try:
 except ImportError:
     HAS_EXLLAMA = False

-from typing import Optional
-
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -157,7 +155,7 @@ def get_linear(weight, bias, quantize):
             qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
         except Exception:
             raise NotImplementedError(
-                f"The passed weight is not `gptq` compatible, loader needs to be updated."
+                "The passed weight is not `gptq` compatible, loader needs to be updated."
             )

         if use_exllama:
@@ -376,11 +374,23 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb

+    def _create_inv_freq(dim, base, device):
+        return 1.0 / (
+            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+        )
+
+    def _get_rope_config(config):
+        if os.getenv("ROPE_SCALING", None) is not None:
+            rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
+            return rope_scaling
+        return getattr(config, "rope_scaling", None)
+
     class PositionRotaryEmbedding(nn.Module):
-        def __init__(self, inv_freq):
+        def __init__(self, inv_freq, scaling_factor):
             super().__init__()
             self.inv_freq = inv_freq
+            self.scaling_factor = scaling_factor
             self._seq_len_cached = 0
             self._cos_cached = None
             self._sin_cached = None
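The two module-level helpers added above centralize the inverse-frequency computation and the scaling configuration, with the `ROPE_SCALING` / `ROPE_FACTOR` environment variables taking precedence over the model config. A rough illustration of that precedence (the `cfg` object and the values are made up for the example, and it assumes `os` is imported at module level):

    import os
    from types import SimpleNamespace

    cfg = SimpleNamespace(rope_scaling={"type": "linear", "factor": 4.0})

    # No override set: fall back to the model config (None if it has no rope_scaling).
    assert _get_rope_config(cfg) == {"type": "linear", "factor": 4.0}

    # Environment variables win over whatever the config says.
    os.environ["ROPE_SCALING"] = "dynamic"
    os.environ["ROPE_FACTOR"] = "2.0"
    assert _get_rope_config(cfg) == {"type": "dynamic", "factor": 2.0}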
@@ -388,38 +398,66 @@ try:
             self._sin_k_cached = None

         @classmethod
-        def static(cls, dim, base, device):
-            inv_freq = 1.0 / (
-                base
-                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
-            )
-            return cls(inv_freq)
+        def static(cls, dim, base: float, device, config):
+            inv_freq = _create_inv_freq(dim, base, device)
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                rope_type = rope_scaling["type"]
+                if rope_type == "linear":
+                    pass
+                elif rope_type == "dynamic":
+                    return DynamicPositionRotaryEmbedding(
+                        dim=dim,
+                        max_position_embeddings=config.max_position_embeddings,
+                        base=base,
+                        device=inv_freq.device,
+                        scaling_factor=scaling_factor,
+                    )
+                else:
+                    raise NotImplementedError(
+                        f"rope scaling type {rope_type} is not implemented or invalid"
+                    )
+            return cls(inv_freq, scaling_factor)

         @classmethod
-        def load(cls, prefix, weights):
+        def load(cls, prefix: str, weights, config):
             # XXX: Always load this in float32 !
             dtype = weights.dtype
             weights.dtype = torch.float32
             inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
             weights.dtype = dtype
-            return cls(inv_freq)
+
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                if rope_scaling["type"] == "linear":
+                    pass
+                elif rope_scaling["type"] == "dynamic":
+                    return DynamicPositionRotaryEmbedding(
+                        dim=2 * inv_freq.shape[0],
+                        max_position_embeddings=config.max_position_embeddings,
+                        base=10000.0,
+                        device=inv_freq.device,
+                        scaling_factor=scaling_factor,
+                    )
+                else:
+                    raise NotImplementedError(
+                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+                    )
+            return cls(inv_freq, scaling_factor)

         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
-            if (
-                seqlen > self._seq_len_cached
-                or self._cos_cached.device != device
-                or self._cos_cached.dtype != dtype
-            ):
+            if self._should_update(device, dtype, seqlen):
                 self._seq_len_cached = seqlen
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
+                self._set_cos_sin_cache(dtype, t)
+
+        def _set_cos_sin_cache(self, dtype, t):
             # Don't do einsum, it converts fp32 to fp16
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = torch.cos(freqs).to(dtype)
             self._sin_cached = torch.sin(freqs).to(dtype)

+        def _should_update(self, device, dtype, seqlen):
+            return (
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
+            )
+
         def get_cos_sin(
             self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
         ):
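For the `linear` scaling type the only behavioural change is the `t /= self.scaling_factor` division in `_update_cos_sin_cache`: position indices are compressed before the outer product with `inv_freq`, so a context `factor` times longer than the trained one reuses the trained range of rotation angles. A standalone sketch of that effect (dimension, factor and lengths are illustrative):

    import torch

    dim, base, factor = 8, 10000.0, 4.0
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

    t = torch.arange(8192, dtype=torch.float32)
    freqs_scaled = torch.outer(t / factor, inv_freq)  # linear scaling: compress positions
    freqs_plain = torch.outer(t, inv_freq)            # no scaling

    # Position 8188 with scaling sees the same angles as position 2047 without it.
    assert torch.allclose(freqs_scaled[8188], freqs_plain[2047])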
@@ -441,5 +479,26 @@ try:
             rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
             return x

+    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
+            inv_freq = _create_inv_freq(dim, base, device)
+            super().__init__(inv_freq, scaling_factor)
+            self.dim = dim
+            self.max_position_embeddings = max_position_embeddings
+            self.base = base
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            if self._should_update(device, dtype, seqlen):
+                if seqlen > self.max_position_embeddings:
+                    newbase = self.base * (
+                        (self.scaling_factor * seqlen / self.max_position_embeddings)
+                        - (self.scaling_factor - 1)
+                    ) ** (self.dim / (self.dim - 2))
+                    self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
+                self._seq_len_cached = seqlen
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                self._set_cos_sin_cache(dtype, t)
+
 except ImportError:
     pass
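The new `DynamicPositionRotaryEmbedding` is the "dynamic" (NTK-aware) variant: instead of compressing positions, it enlarges the rotary base once the live sequence length exceeds `max_position_embeddings`, stretching the lowest frequencies to cover the longer context while leaving shorter sequences untouched. A quick numeric sketch of the `newbase` update (head dimension, trained length and factor are illustrative):

    base, dim = 10000.0, 128
    max_position_embeddings, scaling_factor = 4096, 2.0

    seqlen = 8192  # twice the trained context
    newbase = base * (
        (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))
    print(newbase)  # ~30528: inv_freq shrinks, so the rotation wavelengths grow

    # For seqlen <= max_position_embeddings the branch is skipped and the
    # original base is kept, so short-sequence behaviour is unchanged.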