Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
implement dyno rope
This commit is contained in:
parent
50fd663b9b
commit
de7748ff41
@@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
+        rope_scaling=None,
         rope_theta=10000.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
         self.pretraining_tp = pretraining_tp
         self.use_cache = use_cache
+        self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta

         super().__init__(
             pad_token_id=pad_token_id,
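For context, the `rope_scaling` value plumbed through here is the small dict used by Hugging Face Llama configs; a minimal sketch with illustrative values (not taken from this commit):

    # Illustrative values only; this shows the shape of the field the new code reads.
    config = LlamaConfig(
        rope_scaling={"type": "dynamic", "factor": 2.0},  # or {"type": "linear", "factor": 4.0}
        rope_theta=10000.0,
    )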
@@ -185,8 +187,8 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            dim=self.head_size, base=config.rope_theta, device=weights.device,
         )

         self.softmax_scale = self.head_size**-0.5
@@ -102,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = self.num_heads // weights.process_group.size()

         self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            prefix=f"{prefix}.rotary_emb", weights=weights, config=config,
         )

         self.softmax_scale = self.head_size ** (-0.5)
@@ -133,7 +133,7 @@ class FlashRWAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
-            dim=self.head_size, base=10000.0, device=weights.device
+            dim=self.head_size, base=10000.0, device=weights.device, config=config,
         )
         self.softmax_scale = self.head_size ** (-0.5)
@@ -247,7 +247,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.head_size = hidden_size // num_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
-            self.head_size, base=10000.0, device=weights.device
+            self.head_size, base=10000.0, device=weights.device, config=config,
         )
         self.softmax_scale = self.head_size ** (-0.5)
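In the hunks above, Llama switches from loading `rotary_emb.inv_freq` out of the checkpoint to computing it from `head_size` and `config.rope_theta`, while NeoX and Falcon keep their constructors but now also pass the model config so rope scaling can be picked up. A rough sketch of the two call shapes, with names taken from the diff:

    # Checkpoint-backed inverse frequencies (the NeoX path), now given the config:
    rotary_emb = PositionRotaryEmbedding.load(
        prefix=f"{prefix}.rotary_emb", weights=weights, config=config,
    )

    # Computed inverse frequencies (the Falcon path), now also given the config:
    rotary_emb = PositionRotaryEmbedding.static(
        dim=self.head_size, base=10000.0, device=weights.device, config=config,
    )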
@@ -26,8 +26,6 @@ try:
 except ImportError:
     HAS_EXLLAMA = False

-from typing import Optional
-
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -157,7 +155,7 @@ def get_linear(weight, bias, quantize):
             qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
         except Exception:
             raise NotImplementedError(
-                f"The passed weight is not `gptq` compatible, loader needs to be updated."
+                "The passed weight is not `gptq` compatible, loader needs to be updated."
             )

         if use_exllama:
@@ -376,11 +374,23 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb

+    def _create_inv_freq(dim, base, device):
+        return 1.0 / (
+            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+        )
+
+    def _get_rope_config(config):
+        if os.getenv("ROPE_SCALING", None) is not None:
+            rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
+            return rope_scaling
+        return getattr(config, "rope_scaling", None)
+
     class PositionRotaryEmbedding(nn.Module):
-        def __init__(self, inv_freq):
+        def __init__(self, inv_freq, scaling_factor):
             super().__init__()

             self.inv_freq = inv_freq
+            self.scaling_factor = scaling_factor
             self._seq_len_cached = 0
             self._cos_cached = None
             self._sin_cached = None
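`_create_inv_freq` produces the standard RoPE inverse frequencies, inv_freq[i] = 1 / base**(2i/dim), and `_get_rope_config` lets the ROPE_SCALING / ROPE_FACTOR environment variables override whatever the model config says. A minimal sketch of that override path (the values here are illustrative):

    import os

    # Force dynamic scaling with factor 2.0 regardless of config.rope_scaling.
    os.environ["ROPE_SCALING"] = "dynamic"
    os.environ["ROPE_FACTOR"] = "2.0"

    # With these set, _get_rope_config(config) returns
    # {"type": "dynamic", "factor": 2.0} even if config.rope_scaling is None.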
@@ -388,38 +398,66 @@ try:
             self._sin_k_cached = None

         @classmethod
-        def static(cls, dim, base, device):
-            inv_freq = 1.0 / (
-                base
-                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
-            )
-            return cls(inv_freq)
+        def static(cls, dim, base: float, device, config):
+            inv_freq = _create_inv_freq(dim, base, device)
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                rope_type = rope_scaling["type"]
+                if rope_type == "linear":
+                    pass
+                elif rope_type == "dynamic":
+                    return DynamicPositionRotaryEmbedding(
+                        dim=dim, max_position_embeddings=config.max_position_embeddings,
+                        base=base, device=inv_freq.device, scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_type} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)

         @classmethod
-        def load(cls, prefix, weights):
+        def load(cls, prefix: str, weights, config):
             # XXX: Always load this in float32 !
             dtype = weights.dtype
             weights.dtype = torch.float32
             inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
             weights.dtype = dtype
-            return cls(inv_freq)
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                if rope_scaling["type"] == "linear":
+                    pass
+                elif rope_scaling["type"] == "dynamic":
+                    return DynamicPositionRotaryEmbedding(dim=2 * inv_freq.shape[0],
+                                                          max_position_embeddings=config.max_position_embeddings,
+                                                          base=10000.0, device=inv_freq.device,
+                                                          scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)

         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
-            if (
-                seqlen > self._seq_len_cached
-                or self._cos_cached.device != device
-                or self._cos_cached.dtype != dtype
-            ):
+            if self._should_update(device, dtype, seqlen):
                 self._seq_len_cached = seqlen
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
+                self._set_cos_sin_cache(dtype, t)
+
+        def _set_cos_sin_cache(self, dtype, t):
             # Don't do einsum, it converts fp32 to fp16
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = torch.cos(freqs).to(dtype)
             self._sin_cached = torch.sin(freqs).to(dtype)

+        def _should_update(self, device, dtype, seqlen):
+            return seqlen > self._seq_len_cached or self._cos_cached.device != device \
+                or self._cos_cached.dtype != dtype
+
         def get_cos_sin(
             self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
         ):
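For the "linear" type, both constructors fall through to the base class and `_update_cos_sin_cache` divides the position indices by `scaling_factor` before building the cos/sin tables, so positions up to factor × max_position_embeddings map back onto the angle range the model was trained on. A standalone toy illustration of that effect (not TGI code; the dimension and lengths are made up):

    import torch

    dim, base, scaling_factor = 64, 10000.0, 4.0
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

    t = torch.arange(8192, dtype=torch.float32)
    t /= scaling_factor                  # positions 0..8191 squeezed into 0..2047.75
    freqs = torch.outer(t, inv_freq)     # same angle range as an unscaled 2048-token cache
    cos, sin = torch.cos(freqs), torch.sin(freqs)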
@@ -441,5 +479,26 @@ try:
             rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
             return x

+
+    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
+            inv_freq = _create_inv_freq(dim, base, device)
+            super().__init__(inv_freq, scaling_factor)
+            self.dim = dim
+            self.max_position_embeddings = max_position_embeddings
+            self.base = base
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            if self._should_update(device, dtype, seqlen):
+                if seqlen > self.max_position_embeddings:
+                    newbase = self.base * (
+                        (self.scaling_factor * seqlen / self.max_position_embeddings) -
+                        (self.scaling_factor - 1)
+                    ) ** (self.dim / (self.dim - 2))
+                    self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
+                self._seq_len_cached = seqlen
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                self._set_cos_sin_cache(dtype, t)
+
 except ImportError:
     pass
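`DynamicPositionRotaryEmbedding` implements the dynamic NTK-style variant: once the requested sequence length exceeds `max_position_embeddings`, the rotary base is enlarged to base * ((factor * seqlen / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) and the inverse frequencies are rebuilt, instead of rescaling the positions. A quick numeric check of that formula under assumed values (base 10000, dim 128, factor 2, max_position_embeddings 4096, seqlen 8192):

    base, dim = 10000.0, 128
    scaling_factor, max_pos, seqlen = 2.0, 4096, 8192

    newbase = base * (
        (scaling_factor * seqlen / max_pos) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))
    print(newbase)  # ≈ 30528: longer contexts get a larger base, i.e. slower-rotating frequencies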