mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00

Implement scaled and dynamically scaled RoPE

This commit is contained in:
parent a2cf1bdb2f
commit f01c11bd0c
@@ -162,7 +162,7 @@ struct Args {
     /// Limits the number of tokens for the prefill operation.
     /// Since this operation take the most memory and is compute bound, it is interesting
     /// to limit the number of requests that can be sent.
-    #[clap(default_value = "4096", long, env)]
+    #[clap(default_value = "2048", long, env)]
     max_batch_prefill_tokens: u32,

     /// **IMPORTANT** This is one critical control to allow maximum usage
@@ -182,7 +182,7 @@ struct Args {
     /// depends on other parameters like if you're using quantization, flash attention
     /// or the model implementation, text-generation-inference cannot infer this number
     /// automatically.
-    #[clap(default_value = "16000", long, env)]
+    #[clap(default_value = "8192", long, env)]
     max_batch_total_tokens: u32,

     /// This setting defines how many tokens can be passed before forcing the waiting
@@ -280,6 +280,19 @@ struct Args {
     /// Display a lot of information about your runtime environment
     #[clap(long, short, action)]
     env: bool,

+    /// NTK-Aware Scaled RoPE is a method proposed in https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+    /// The scale factor, or "α", is used in combination with a non-linearity to scale the base used to calculate the parameter "θ", the angle of rotation in RoPE.
+    /// This increases how many input tokens can be represented within the same portion of a positional embedding, with the non-linearity used to increase token separability.
+    #[clap(default_value = "1", long, env)]
+    rope_scale_factor: usize,
+
+    /// Dynamic scaling of the "α" factor in NTK-Aware Scaled RoPE was introduced in https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/
+    /// Instead of setting alpha statically, it is calculated as a function of the current sequence length and the model's base sequence length.
+    /// This both improves performance on shorter sequence lengths and smooths the perplexity explosion experienced by both linearly scaled and NTK-Aware scaled RoPE.
+    /// If this is enabled, the above "rope_scale_factor" will be ignored.
+    #[clap(default_value = "false", long, env)]
+    rope_dynamic_scaling: bool
 }

 #[derive(Debug)]
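For reference, a sketch of the math this commit implements (see PositionRotaryEmbedding._get_inv_freq and _update_cos_sin_cache further down), writing d for the rotary dimension, b = 10000 for the original RoPE base, α for rope_scale_factor, L for the current sequence length and L_0 for the model's base sequence length:

    \theta_i = \left(b \cdot \alpha^{d/(d-2)}\right)^{-2i/d}, \qquad i = 0, 1, \dots, d/2 - 1

and, with rope_dynamic_scaling enabled, α is recomputed from the current sequence length as

    \alpha_{\mathrm{dyn}} = \frac{\alpha L}{L_0} - (\alpha - 1)

which reduces to α_dyn = 1 (no scaling) at L = L_0 and grows linearly for longer sequences.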
@@ -293,6 +306,8 @@ fn shard_manager(
     model_id: String,
     revision: Option<String>,
     quantize: Option<Quantization>,
+    rope_scale_factor: usize,
+    rope_dynamic_scaling: bool,
     dtype: Option<Dtype>,
     trust_remote_code: bool,
     uds_path: String,
@@ -422,6 +437,10 @@ fn shard_manager(
         envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
     }

+    // RoPE scaling: forward the launcher arguments to the shard as environment variables
+    envs.push(("ROPE_SCALE_FACTOR".into(), rope_scale_factor.to_string().into()));
+    envs.push(("ROPE_DYNAMIC_SCALING".into(), rope_dynamic_scaling.to_string().into()));
+
     // Start process
     tracing::info!("Starting shard {rank}");
     let mut p = match Command::new("text-generation-server")
@@ -776,11 +795,16 @@ fn spawn_shards(
         let disable_custom_kernels = args.disable_custom_kernels;
         let watermark_gamma = args.watermark_gamma;
         let watermark_delta = args.watermark_delta;
+        let rope_scale_factor = args.rope_scale_factor;
+        let rope_dynamic_scaling = args.rope_dynamic_scaling;
+
         thread::spawn(move || {
             shard_manager(
                 model_id,
                 revision,
                 quantize,
+                rope_scale_factor,
+                rope_dynamic_scaling,
                 dtype,
                 trust_remote_code,
                 uds_path,
@@ -35,9 +35,9 @@ struct Args {
     max_total_tokens: usize,
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
-    #[clap(default_value = "4096", long, env)]
+    #[clap(default_value = "2048", long, env)]
     max_batch_prefill_tokens: u32,
-    #[clap(default_value = "16000", long, env)]
+    #[clap(default_value = "8192", long, env)]
     max_batch_total_tokens: u32,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import torch
 import torch.distributed

@@ -41,6 +42,12 @@ from text_generation_server.utils.layers import (
     TensorParallelHead,
 )

+ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
+
+if os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true":
+    ROPE_DYNAMIC_SCALING = True
+else:
+    ROPE_DYNAMIC_SCALING = False

 class LlamaRMSNorm(nn.Module):
     def __init__(self, prefix, weights, eps=1e-6):
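This module-level ROPE_SCALE_FACTOR / ROPE_DYNAMIC_SCALING parsing block is repeated verbatim in each of the modeling files touched by this commit. A possible consolidation, not part of the commit, is sketched below; the helper name get_rope_scaling_from_env and its location are hypothetical:

import os
from typing import Tuple


def get_rope_scaling_from_env() -> Tuple[int, bool]:
    # Hypothetical helper, not in the commit: reads the same two environment
    # variables that the launcher forwards to every shard.
    scale_factor = int(os.getenv("ROPE_SCALE_FACTOR", "1"))
    dynamic_scaling = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
    return scale_factor, dynamic_scaling


# Usage in a modeling file, replacing the duplicated module-level block:
ROPE_SCALE_FACTOR, ROPE_DYNAMIC_SCALING = get_rope_scaling_from_env()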
@@ -105,10 +112,18 @@ class FlashLlamaAttention(torch.nn.Module):
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
+        self.scale_factor = ROPE_SCALE_FACTOR
+        self.dynamic_scaling = ROPE_DYNAMIC_SCALING

-        self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
-        )
+        if self.scale_factor > 1:
+            # Base before scaling is 10000 per the original RoPE paper
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling
+            )
+        else:
+            self.rotary_emb = PositionRotaryEmbedding.load(
+                prefix=f"{prefix}.rotary_emb", weights=weights
+            )

         self.softmax_scale = self.head_size**-0.5

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import torch
 import torch.distributed

@@ -45,6 +46,14 @@ from text_generation_server.utils.layers import (
 )


+ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
+
+if os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true":
+    ROPE_DYNAMIC_SCALING = True
+else:
+    ROPE_DYNAMIC_SCALING = False
+
+
 def load_row(config, prefix: str, weights, bias: bool):
     weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

@@ -102,10 +111,18 @@ class FlashNeoxAttention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
+        self.scale_factor = ROPE_SCALE_FACTOR
+        self.dynamic_scaling = ROPE_DYNAMIC_SCALING

-        self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
-        )
+        if self.scale_factor > 1:
+            # Base before scaling is 10000 per the original RoPE paper
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling, config.max_position_embeddings
+            )
+        else:
+            self.rotary_emb = PositionRotaryEmbedding.load(
+                prefix=f"{prefix}.rotary_emb", weights=weights
+            )

         self.softmax_scale = self.head_size ** (-0.5)

@@ -1,5 +1,7 @@
+import os
 import torch
 import torch.distributed
+import warnings

 from torch import nn
 from transformers.modeling_utils import PreTrainedModel
@@ -23,6 +25,12 @@ from text_generation_server.utils.layers import (
     get_linear,
 )

+ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
+
+if os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true":
+    ROPE_DYNAMIC_SCALING = True
+else:
+    ROPE_DYNAMIC_SCALING = False

 def load_row(config, prefix: str, weights, bias: bool):
     weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
@@ -113,10 +121,13 @@ class FlashRWAttention(torch.nn.Module):
         self.num_heads_kv = config.n_head_kv
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
+        self.scale_factor = ROPE_SCALE_FACTOR
+        self.dynamic_scaling = ROPE_DYNAMIC_SCALING

         self.rotary_emb = PositionRotaryEmbedding.static(
-            dim=self.head_size, base=10000.0, device=weights.device
+            dim=self.head_size, base=10000.0, device=weights.device, scale_factor=self.scale_factor, dynamic_scaling=self.dynamic_scaling
         )

         self.softmax_scale = self.head_size ** (-0.5)

         if self.num_heads % weights.process_group.size() != 0:
@@ -239,9 +250,11 @@ class FlashRWLargeAttention(torch.nn.Module):

         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+        self.scale_factor = ROPE_SCALE_FACTOR
+        self.dynamic_scaling = ROPE_DYNAMIC_SCALING

         self.rotary_emb = PositionRotaryEmbedding.static(
-            self.head_size, base=10000.0, device=weights.device
+            dim=self.head_size, base=10000.0, device=weights.device, scale_factor=self.scale_factor, dynamic_scaling=self.dynamic_scaling
         )
         self.softmax_scale = self.head_size ** (-0.5)

@@ -61,6 +61,14 @@ if not CUSTOM_KERNELS_ENABLED:
     logger.warning("We're not using custom kernels.")


+ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
+
+if os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true":
+    ROPE_DYNAMIC_SCALING = True
+else:
+    ROPE_DYNAMIC_SCALING = False
+
+
 def make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
 ) -> torch.BoolTensor:
@@ -369,7 +369,7 @@ try:
    import rotary_emb

    class PositionRotaryEmbedding(nn.Module):
-       def __init__(self, inv_freq):
+       def __init__(self, inv_freq, scale_factor=1, dynamic_scaling=False, max_seq_len=2048, dim=None, base=None):
            super().__init__()

            self.inv_freq = inv_freq
@@ -379,32 +379,62 @@ try:
            self._cos_k_cached = None
            self._sin_k_cached = None

-       @classmethod
-       def static(cls, dim, base, device):
-           inv_freq = 1.0 / (
-               base
-               ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
-           )
-           return cls(inv_freq)
+           self.scale_factor = scale_factor
+           self.dynamic_scaling = dynamic_scaling
+           self.original_max_seq_len = max_seq_len
+           self.max_seq_len = max_seq_len * scale_factor
+           self.dim = dim
+           self.base = base
+
+       @classmethod
+       def static(cls, dim, base, device, scale_factor=1, dynamic_scaling=False, max_seq_len=2048):
+           inv_freq = cls._get_inv_freq(dim, base, device, scale_factor)
+           return cls(inv_freq, scale_factor, dynamic_scaling, max_seq_len, dim, base)

        @classmethod
        def load(cls, prefix, weights):
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32

            inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            weights.dtype = dtype
            return cls(inv_freq)

+       @staticmethod
+       def _get_inv_freq(dim, base, device, scale_factor=1):
+           base = base * scale_factor ** (dim / (dim-2))
+
+           inv_freq = 1.0 / (
+               base
+               ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+           )
+
+           return inv_freq
+
        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
+
+           length = seqlen
+           max_seq_len = self.max_seq_len
+           inv_freq = self.inv_freq
+
+           if self.dynamic_scaling:
+               scale_factor = (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1)
+               max_seq_len = self.original_max_seq_len * scale_factor
+               inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
+               self.register_buffer("inv_freq", inv_freq)
+
+           if self.scale_factor > 1:
+               length = max(seqlen, max_seq_len)
+
            if (
-               seqlen > self._seq_len_cached
+               length > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
-               self._seq_len_cached = seqlen
+               self._seq_len_cached = length
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
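To make the effect of the two settings concrete, here is a small standalone sketch, not part of the commit, that reproduces the inverse-frequency and dynamic-alpha computations above with example values (a head size of 128 is assumed for illustration; 2048 is the same default base length used as max_seq_len in this diff):

import torch


def get_inv_freq(dim: int, base: float, scale_factor: float = 1.0) -> torch.Tensor:
    # NTK-aware scaling: stretch the base before computing the per-dimension frequencies
    base = base * scale_factor ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))


def dynamic_alpha(alpha: float, seqlen: int, original_max_seq_len: int) -> float:
    # Dynamic NTK scaling: recompute alpha from the current sequence length
    return (alpha * seqlen / original_max_seq_len) - (alpha - 1)


dim, base, original_max = 128, 10000.0, 2048

print(get_inv_freq(dim, base)[:3])                  # unscaled frequencies
print(get_inv_freq(dim, base, scale_factor=4)[:3])  # lower frequencies -> longer effective context

print(dynamic_alpha(4, 2048, original_max))  # 1.0  -> no scaling at the base sequence length
print(dynamic_alpha(4, 8192, original_max))  # 13.0 -> stronger scaling at 4x the base length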