From de7748ff419f949f9cc3b96b35a3d01e57cc8d69 Mon Sep 17 00:00:00 2001
From: Yessen Kanapin
Date: Fri, 25 Aug 2023 13:20:31 -0700
Subject: [PATCH] implement dyno rope

---
 .../custom_modeling/flash_llama_modeling.py   |   6 +-
 .../custom_modeling/flash_neox_modeling.py    |   2 +-
 .../custom_modeling/flash_rw_modeling.py      |   4 +-
 server/text_generation_server/utils/layers.py | 105 ++++++++++++++----
 4 files changed, 89 insertions(+), 28 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index b6285856..66e2fce0 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_scaling=None,
+        rope_theta=10000.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
         self.pretraining_tp = pretraining_tp
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
+        self.rope_theta = rope_theta
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -185,8 +187,8 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            dim=self.head_size, base=config.rope_theta, device=weights.device, config=config,
         )
 
         self.softmax_scale = self.head_size**-0.5
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index e7c8ced4..5be65fa3 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -102,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            prefix=f"{prefix}.rotary_emb", weights=weights, config=config,
         )
 
         self.softmax_scale = self.head_size ** (-0.5)
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 3570b283..e23edf3a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -133,7 +133,7 @@ class FlashRWAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            dim=self.head_size, base=10000.0, device=weights.device
+            dim=self.head_size, base=10000.0, device=weights.device, config=config,
         )
 
         self.softmax_scale = self.head_size ** (-0.5)
@@ -247,7 +247,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.head_size = hidden_size // num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            self.head_size, base=10000.0, device=weights.device
+            self.head_size, base=10000.0, device=weights.device, config=config,
        )
 
         self.softmax_scale = self.head_size ** (-0.5)
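Note on the model-side changes above: LlamaConfig gains a rope_theta field (default 10000.0), and FlashLlamaAttention now builds its rotary embedding from the config via PositionRotaryEmbedding.static(dim=head_size, base=rope_theta, ...) instead of loading a precomputed inv_freq tensor from the checkpoint, while the NeoX and Falcon/RW modules simply thread config through so scaling settings can reach the embedding. The sketch below shows what static derives from those two values, mirroring _create_inv_freq in layers.py; the concrete head_size value is illustrative, not taken from the patch.

# Sketch: inverse frequencies derived from head_size and rope_theta,
# mirroring _create_inv_freq below (head_size here is an assumed example value).
import torch

head_size = 128        # assumed per-head dimension
rope_theta = 10000.0   # the LlamaConfig default added by this patch

inv_freq = 1.0 / (
    rope_theta
    ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size)
)
# inv_freq has head_size // 2 entries; pair i of each query/key head is rotated
# by angle position * inv_freq[i] once the cos/sin caches are applied.
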
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 7a45808e..30b4886c 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -26,8 +26,6 @@ try:
 except ImportError:
     HAS_EXLLAMA = False
 
-from typing import Optional
-
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -157,7 +155,7 @@ def get_linear(weight, bias, quantize):
             qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
         except Exception:
             raise NotImplementedError(
-                f"The passed weight is not `gptq` compatible, loader needs to be updated."
+                "The passed weight is not `gptq` compatible, loader needs to be updated."
             )
 
         if use_exllama:
@@ -376,11 +374,23 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 
+    def _create_inv_freq(dim, base, device):
+        return 1.0 / (
+            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+        )
+
+    def _get_rope_config(config):
+        if os.getenv("ROPE_SCALING", None) is not None:
+            rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
+            return rope_scaling
+        return getattr(config, "rope_scaling", None)
+
     class PositionRotaryEmbedding(nn.Module):
-        def __init__(self, inv_freq):
+        def __init__(self, inv_freq, scaling_factor):
             super().__init__()
 
             self.inv_freq = inv_freq
+            self.scaling_factor = scaling_factor
             self._seq_len_cached = 0
             self._cos_cached = None
             self._sin_cached = None
@@ -388,37 +398,65 @@ try:
             self._sin_k_cached = None
 
         @classmethod
-        def static(cls, dim, base, device):
-            inv_freq = 1.0 / (
-                base
-                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
-            )
-            return cls(inv_freq)
+        def static(cls, dim, base: float, device, config):
+            inv_freq = _create_inv_freq(dim, base, device)
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                rope_type = rope_scaling["type"]
+                if rope_type == "linear":
+                    pass
+                elif rope_type == "dynamic":
+                    return DynamicPositionRotaryEmbedding(
+                        dim=dim, max_position_embeddings=config.max_position_embeddings,
+                        base=base, device=inv_freq.device, scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_type} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)
 
         @classmethod
-        def load(cls, prefix, weights):
+        def load(cls, prefix: str, weights, config):
             # XXX: Always load this in float32 !
             dtype = weights.dtype
             weights.dtype = torch.float32
             inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
             weights.dtype = dtype
-            return cls(inv_freq)
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                if rope_scaling["type"] == "linear":
+                    pass
+                elif rope_scaling["type"] == "dynamic":
+                    return DynamicPositionRotaryEmbedding(dim=2 * inv_freq.shape[0],
+                                                          max_position_embeddings=config.max_position_embeddings,
+                                                          base=10000.0, device=inv_freq.device,
+                                                          scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)
 
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
-            if (
-                seqlen > self._seq_len_cached
-                or self._cos_cached.device != device
-                or self._cos_cached.dtype != dtype
-            ):
+            if self._should_update(device, dtype, seqlen):
                 self._seq_len_cached = seqlen
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-                # Don't do einsum, it converts fp32 to fp16
-                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
-                self._cos_cached = torch.cos(freqs).to(dtype)
-                self._sin_cached = torch.sin(freqs).to(dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
+                self._set_cos_sin_cache(dtype, t)
+
+        def _set_cos_sin_cache(self, dtype, t):
+            # Don't do einsum, it converts fp32 to fp16
+            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)
+
+        def _should_update(self, device, dtype, seqlen):
+            return seqlen > self._seq_len_cached or self._cos_cached.device != device \
+                or self._cos_cached.dtype != dtype
 
         def get_cos_sin(
             self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
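The hunk above resolves where the scaling configuration comes from: the ROPE_SCALING/ROPE_FACTOR environment variables take precedence, otherwise config.rope_scaling is used, and static/load either keep the plain embedding (optionally with a linear factor) or hand off to DynamicPositionRotaryEmbedding, added in the next hunk. A rough sketch of the linear path follows; the dim/base/seqlen values and the factor of 4.0 are assumed for illustration, not taken from the patch.

# Sketch of the "linear" scaling path: positions are divided by the factor before
# the cos/sin cache is built, as _update_cos_sin_cache does when scaling_factor is set.
import torch

dim, base, seqlen = 128, 10000.0, 8192   # illustrative values
scaling_factor = 4.0                     # e.g. ROPE_SCALING=linear ROPE_FACTOR=4.0

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(seqlen, dtype=torch.float32)
t /= scaling_factor                      # position 8191 is treated like ~2048
freqs = torch.outer(t, inv_freq)         # same outer product as _set_cos_sin_cache
cos, sin = torch.cos(freqs), torch.sin(freqs)
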
@@ -436,10 +474,31 @@ try:
         def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
             rotary_dim = cos.shape[-1]
             x1 = x[..., :rotary_dim]
-            x2 = x[..., rotary_dim : 2 * rotary_dim]
+            x2 = x[..., rotary_dim: 2 * rotary_dim]
 
             rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
             return x
+
+    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
+            inv_freq = _create_inv_freq(dim, base, device)
+            super().__init__(inv_freq, scaling_factor)
+            self.dim = dim
+            self.max_position_embeddings = max_position_embeddings
+            self.base = base
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            if self._should_update(device, dtype, seqlen):
+                if seqlen > self.max_position_embeddings:
+                    newbase = self.base * (
+                        (self.scaling_factor * seqlen / self.max_position_embeddings) -
+                        (self.scaling_factor - 1)
+                    ) ** (self.dim / (self.dim - 2))
+                    self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
+                self._seq_len_cached = seqlen
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                self._set_cos_sin_cache(dtype, t)
+
 
 except ImportError:
     pass
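In the dynamic case the base itself is rescaled once the requested sequence length exceeds max_position_embeddings, following the NTK-aware formula in DynamicPositionRotaryEmbedding._update_cos_sin_cache; positions are not divided by the factor. A worked example of that base adjustment with assumed values (dim=128, factor=2.0, 4096 trained positions) is:

# Worked example of the dynamic base adjustment; all numbers are illustrative.
dim = 128
base = 10000.0
scaling_factor = 2.0
max_position_embeddings = 4096

def adjusted_base(seqlen: int) -> float:
    # Mirrors the newbase computation in DynamicPositionRotaryEmbedding above.
    if seqlen <= max_position_embeddings:
        return base
    return base * (
        (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

print(adjusted_base(4096))   # 10000.0 -> unchanged within the trained window
print(adjusted_base(8192))   # ~30500  -> longer wavelengths for longer contexts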