implement dyno rope

This commit is contained in:
Yessen Kanapin 2023-08-25 13:20:31 -07:00
parent 50fd663b9b
commit de7748ff41
4 changed files with 89 additions and 28 deletions

View File

@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
rope_theta=10000.0,
**kwargs,
):
self.vocab_size = vocab_size
@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
@ -185,8 +187,8 @@ class FlashLlamaAttention(torch.nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=config.rope_theta, device=weights.device,
)
self.softmax_scale = self.head_size**-0.5

View File

@ -102,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.num_heads = self.num_heads // weights.process_group.size()
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
prefix=f"{prefix}.rotary_emb", weights=weights, config=config,
)
self.softmax_scale = self.head_size ** (-0.5)

View File

@ -133,7 +133,7 @@ class FlashRWAttention(torch.nn.Module):
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
dim=self.head_size, base=10000.0, device=weights.device, config=config,
)
self.softmax_scale = self.head_size ** (-0.5)
@ -247,7 +247,7 @@ class FlashRWLargeAttention(torch.nn.Module):
self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, base=10000.0, device=weights.device
self.head_size, base=10000.0, device=weights.device, config=config,
)
self.softmax_scale = self.head_size ** (-0.5)

View File

@ -26,8 +26,6 @@ try:
except ImportError:
HAS_EXLLAMA = False
from typing import Optional
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
@ -157,7 +155,7 @@ def get_linear(weight, bias, quantize):
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
except Exception:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
"The passed weight is not `gptq` compatible, loader needs to be updated."
)
if use_exllama:
@ -376,11 +374,23 @@ try:
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
def _create_inv_freq(dim, base, device):
return 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
return rope_scaling
return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq):
def __init__(self, inv_freq, scaling_factor):
super().__init__()
self.inv_freq = inv_freq
self.scaling_factor = scaling_factor
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
@ -388,37 +398,65 @@ try:
self._sin_k_cached = None
@classmethod
def static(cls, dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return cls(inv_freq)
def static(cls, dim, base: float, device, config):
inv_freq = _create_inv_freq(dim, base, device)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
rope_type = rope_scaling["type"]
if rope_type == "linear":
pass
elif rope_type == "dynamic":
return DynamicPositionRotaryEmbedding(
dim=dim, max_position_embeddings=config.max_position_embeddings,
base=base, device=inv_freq.device, scaling_factor=scaling_factor)
else:
raise NotImplementedError(f"rope scaling type {rope_type} is not implemented or invalid")
return cls(inv_freq, scaling_factor)
@classmethod
def load(cls, prefix, weights):
def load(cls, prefix: str, weights, config):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
return cls(inv_freq)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=2 * inv_freq.shape[0],
max_position_embeddings=config.max_position_embeddings,
base=10000.0, device=inv_freq.device,
scaling_factor=scaling_factor)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
return cls(inv_freq, scaling_factor)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if self._should_update(device, dtype, seqlen):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
if self.scaling_factor is not None:
t /= self.scaling_factor
self._set_cos_sin_cache(dtype, t)
def _set_cos_sin_cache(self, dtype, t):
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
def _should_update(self, device, dtype, seqlen):
return seqlen > self._seq_len_cached or self._cos_cached.device != device \
or self._cos_cached.dtype != dtype
def get_cos_sin(
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
@ -436,10 +474,31 @@ try:
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
x1 = x[..., :rotary_dim]
x2 = x[..., rotary_dim : 2 * rotary_dim]
x2 = x[..., rotary_dim: 2 * rotary_dim]
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
return x
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
def _update_cos_sin_cache(self, dtype, device, seqlen):
if self.should_update(device, dtype, seqlen):
if seqlen > self.max_position_embeddings:
newbase = self.base * (
(self.scaling_factor * seqlen / self.max_position_embeddings) -
(self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
self._set_cos_sin_cache(dtype, t)
except ImportError:
pass