Adding Rope scaling.

Commit edbba4ea36 (parent 92bb56b0c1) in https://github.com/huggingface/text-generation-inference.git

This commit adds `--rope-scaling` and `--rope-factor` options to the launcher and hands them to the Python server through the `ROPE_SCALING` and `ROPE_FACTOR` environment variables, where `PositionRotaryEmbedding` gains linear and dynamic scaling so RoPE models can handle prompts longer than their trained context.
Launcher (Rust) side:

@@ -60,6 +60,26 @@ impl std::fmt::Display for Dtype {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum RopeScaling {
+    Linear,
+    Dynamic,
+}
+
+impl std::fmt::Display for RopeScaling {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // To keep in sync with `server`.
+        match self {
+            RopeScaling::Linear => {
+                write!(f, "linear")
+            }
+            RopeScaling::Dynamic => {
+                write!(f, "dynamic")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -250,6 +270,27 @@ struct Args {
     #[clap(default_value = "1.0", long, env)]
     cuda_memory_fraction: f32,
 
+    /// Rope scaling will only be used for RoPE models
+    /// and allows rescaling the position rotary to accommodate
+    /// larger prompts.
+    ///
+    /// Goes together with `rope_factor`.
+    ///
+    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
+    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
+    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (nothing will be changed,
+    /// basically)
+    ///
+    /// `--rope-scaling linear --rope-factor` fully describes the scaling you want
+    #[clap(long, env)]
+    rope_scaling: Option<RopeScaling>,
+
+    /// Rope scaling will only be used for RoPE models
+    /// See `rope_scaling`
+    #[clap(long, env)]
+    rope_factor: Option<f32>,
+
     /// Outputs the logs in JSON format (useful for telemetry)
     #[clap(long, env)]
     json_output: bool,
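For reference, the flag combinations resolve as follows; this mirrors the `(rope_scaling, rope_factor)` match added to `shard_manager` further down. A minimal Python sketch with a hypothetical `resolve_rope` helper (not part of the commit):

    from typing import Optional, Tuple

    def resolve_rope(scaling: Optional[str], factor: Optional[float]) -> Optional[Tuple[str, float]]:
        # No flag at all: leave RoPE untouched.
        if scaling is None and factor is None:
            return None
        # A scaling type without a factor defaults to factor 1.0.
        if factor is None:
            return (scaling, 1.0)
        # A factor without a type defaults to linear scaling.
        if scaling is None:
            return ("linear", factor)
        return (scaling, factor)

    assert resolve_rope(None, None) is None
    assert resolve_rope("dynamic", None) == ("dynamic", 1.0)
    assert resolve_rope(None, 4.0) == ("linear", 4.0)
    assert resolve_rope("linear", 2.0) == ("linear", 2.0)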
@@ -305,6 +346,8 @@ fn shard_manager(
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
     cuda_memory_fraction: f32,
+    rope_scaling: Option<RopeScaling>,
+    rope_factor: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -358,6 +401,12 @@ fn shard_manager(
         shard_args.push(revision)
     }
 
+    let rope = match (rope_scaling, rope_factor) {
+        (None, None) => None,
+        (Some(scaling), None) => Some((scaling, 1.0)),
+        (Some(scaling), Some(factor)) => Some((scaling, factor)),
+        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
+    };
+
     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
@@ -395,6 +444,16 @@ fn shard_manager(
         envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
     };
 
+    // Detect rope scaling
+    // Sent as env vars instead of CLI args to avoid bloating everything:
+    // these are only used by RoPE models, so passing the information around
+    // for all models would complicate the code unnecessarily.
+    if let Some((scaling, factor)) = rope {
+        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
+        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
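On the receiving end, the server reads these variables back; the `_get_rope_config` helper added to the Python layers further down does essentially this, with the environment taking precedence over the model config. A small sketch, assuming a `config` object with an optional `rope_scaling` attribute:

    import os

    def rope_config_from_env_or_config(config):
        # Environment variables set by the launcher win over the checkpoint config.
        if os.getenv("ROPE_SCALING") is not None:
            return {
                "type": os.environ["ROPE_SCALING"],          # "linear" or "dynamic"
                "factor": float(os.environ["ROPE_FACTOR"]),  # e.g. 2.0
            }
        return getattr(config, "rope_scaling", None)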
@@ -784,6 +843,8 @@ fn spawn_shards(
         let watermark_gamma = args.watermark_gamma;
         let watermark_delta = args.watermark_delta;
         let cuda_memory_fraction = args.cuda_memory_fraction;
+        let rope_scaling = args.rope_scaling;
+        let rope_factor = args.rope_factor;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -802,6 +863,8 @@ fn spawn_shards(
                 watermark_gamma,
                 watermark_delta,
                 cuda_memory_fraction,
+                rope_scaling,
+                rope_factor,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
Server (Python) side:

@@ -186,7 +186,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         )
 
         self.softmax_scale = self.head_size**-0.5
@@ -102,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         )
 
         self.softmax_scale = self.head_size ** (-0.5)
@@ -133,7 +133,7 @@ class FlashRWAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            dim=self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -247,7 +247,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.head_size = hidden_size // num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -381,33 +381,65 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 
-    class PositionRotaryEmbedding(nn.Module):
-        def __init__(self, inv_freq):
-            super().__init__()
+    def _create_inv_freq(dim, base, device):
+        inv_freq = 1.0 / (
+            base
+            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+        )
+        return inv_freq
+
+    def _get_rope_config(config):
+        if os.getenv("ROPE_SCALING", None) is not None:
+            rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
+            return rope_scaling
+        return getattr(config, "rope_scaling", None)
+
+    class PositionRotaryEmbedding(nn.Module):
+        def __init__(self, inv_freq, scaling_factor):
+            super().__init__()
 
             self.inv_freq = inv_freq
             self._seq_len_cached = 0
             self._cos_cached = None
             self._sin_cached = None
             self._cos_k_cached = None
             self._sin_k_cached = None
+            self.scaling_factor = scaling_factor
+            self.dynamic_args = None
 
         @classmethod
-        def static(cls, dim, base, device):
-            inv_freq = 1.0 / (
-                base
-                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
-            )
-            return cls(inv_freq)
+        def static(cls, config, dim, base, device):
+            inv_freq = _create_inv_freq(dim, base, device)
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                if rope_scaling["type"] == "linear":
+                    pass
+                elif rope_scaling["type"] == "dynamic":
+                    return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)
 
         @classmethod
-        def load(cls, prefix, weights):
+        def load(cls, config, prefix, weights):
             # XXX: Always load this in float32 !
             dtype = weights.dtype
             weights.dtype = torch.float32
             inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
             weights.dtype = dtype
-            return cls(inv_freq)
+
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                if rope_scaling["type"] == "linear":
+                    pass
+                elif rope_scaling["type"] == "dynamic":
+                    return DynamicPositionRotaryEmbedding(dim=2 * inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)
 
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
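For intuition, the `_create_inv_freq` helper above builds the standard RoPE inverse frequencies, inv_freq[i] = base^(-2i/dim). A standalone restatement for reference; dim=128 and base=10000.0 are illustrative values, not something the commit fixes:

    import torch

    def create_inv_freq(dim: int, base: float, device: str = "cpu") -> torch.Tensor:
        # One frequency per rotated pair of dimensions, geometrically spaced from 1 down to ~1/base.
        return 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))

    inv_freq = create_inv_freq(dim=128, base=10000.0)
    print(inv_freq.shape)                            # torch.Size([64])
    print(inv_freq[0].item(), inv_freq[-1].item())   # 1.0 ... ~1.15e-4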
@@ -419,8 +451,11 @@ try:
             ):
                 self._seq_len_cached = seqlen
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
                 # Don't do einsum, it converts fp32 to fp16
                 # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
                 freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)
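This is the whole of linear scaling: the integer positions are divided by the factor before the cos/sin tables are built, so longer prompts are squeezed back into the position range the model was trained on. A small illustration, assuming (purely for the example) a model trained on 4096 positions and `--rope-factor 2.0`:

    import torch

    factor, seqlen = 2.0, 8192
    t = torch.arange(seqlen, dtype=torch.float32)
    t /= factor                 # same in-place division as above
    print(t[-1].item())         # 4095.5, back inside the trained 0..4095 range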
@@ -446,5 +481,36 @@ try:
             rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
             return x
 
+    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
+            inv_freq = _create_inv_freq(dim, base, device)
+            super().__init__(inv_freq, scaling_factor)
+            self.dim = dim
+            self.max_position_embeddings = max_position_embeddings
+            self.base = base
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            # Reset the tables if the sequence length has changed,
+            # or if we're on a new device (possibly due to tracing for instance)
+            if (
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
+            ):
+                if seqlen > self.max_position_embeddings:
+                    newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
+                    self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
+                self._seq_len_cached = seqlen
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
+                # Don't do einsum, it converts fp32 to fp16
+                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+
 except ImportError:
     pass
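The dynamic ("NTK-aware") variant instead enlarges the rotary base once the prompt exceeds `max_position_embeddings`, using the formula in `_update_cos_sin_cache` above. A worked example with illustrative values (head dim 128, base 10000, a 4096-position model, factor 1.0, an 8192-token prompt); these numbers are assumptions for the sketch, not taken from the commit:

    base, dim, max_pos = 10000.0, 128, 4096
    factor, seqlen = 1.0, 8192

    newbase = base * ((factor * seqlen / max_pos) - (factor - 1)) ** (dim / (dim - 2))
    print(round(newbase))  # ~20222 -- a larger base means slower rotations, so 8192 positions still fit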