mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
Adding Rope scaling.
This commit is contained in:
parent 92bb56b0c1
commit edbba4ea36
@@ -60,6 +60,26 @@ impl std::fmt::Display for Dtype {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum RopeScaling {
+    Linear,
+    Dynamic,
+}
+
+impl std::fmt::Display for RopeScaling {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Keep in sync with the values expected by `server`.
+        match self {
+            RopeScaling::Linear => {
+                write!(f, "linear")
+            }
+            RopeScaling::Dynamic => {
+                write!(f, "dynamic")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -250,6 +270,27 @@ struct Args {
     #[clap(default_value = "1.0", long, env)]
     cuda_memory_fraction: f32,
 
+    /// Rope scaling will only be used for RoPE models
+    /// and allows rescaling the position rotary embeddings to accommodate
+    /// larger prompts.
+    ///
+    /// Goes together with `rope_factor`.
+    ///
+    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
+    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
+    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (essentially
+    /// a no-op)
+    ///
+    /// Passing both `--rope-scaling` and `--rope-factor` fully describes the scaling you want
+    #[clap(long, env)]
+    rope_scaling: Option<RopeScaling>,
+
+    /// Rope scaling will only be used for RoPE models
+    /// See `rope_scaling`
+    #[clap(long, env)]
+    rope_factor: Option<f32>,
+
+
     /// Outputs the logs in JSON format (useful for telemetry)
     #[clap(long, env)]
     json_output: bool,
@@ -305,6 +346,8 @@ fn shard_manager(
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
     cuda_memory_fraction: f32,
+    rope_scaling: Option<RopeScaling>,
+    rope_factor: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -358,6 +401,12 @@ fn shard_manager(
         shard_args.push(revision)
     }
 
+    let rope = match (rope_scaling, rope_factor) {
+        (None, None) => None,
+        (Some(scaling), None) => Some((scaling, 1.0)),
+        (Some(scaling), Some(factor)) => Some((scaling, factor)),
+        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
+    };
     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
@@ -395,6 +444,16 @@ fn shard_manager(
         envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
     };
 
+    // Detect rope scaling
+    // Sent as env vars instead of CLI args to avoid bloating everything:
+    // they can only be used by RoPE models, so passing the information around
+    // for all models would complicate the code unnecessarily
+    if let Some((scaling, factor)) = rope {
+        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
+        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
+    }
+
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
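Aside (not part of the diff): a minimal Python sketch of the two launcher rules above, i.e. how the `--rope-scaling` / `--rope-factor` flags default and which ROPE_SCALING / ROPE_FACTOR values the shard environment ends up with. The `resolve_rope` helper name is hypothetical and exists only for illustration.

# Hypothetical helper (illustration only) mirroring the launcher's Rust `match`
# above and the env vars it forwards to each shard.
def resolve_rope(rope_scaling, rope_factor):
    """Return the ROPE_SCALING / ROPE_FACTOR env values the shard would receive."""
    if rope_scaling is None and rope_factor is None:
        return None
    if rope_scaling is None:
        rope_scaling = "linear"  # a bare --rope-factor implies linear scaling
    if rope_factor is None:
        rope_factor = 1.0        # a bare --rope-scaling implies a factor of 1.0
    return {"ROPE_SCALING": rope_scaling, "ROPE_FACTOR": str(rope_factor)}

assert resolve_rope(None, None) is None
assert resolve_rope("dynamic", None) == {"ROPE_SCALING": "dynamic", "ROPE_FACTOR": "1.0"}
assert resolve_rope(None, 2.0) == {"ROPE_SCALING": "linear", "ROPE_FACTOR": "2.0"}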
@@ -784,6 +843,8 @@ fn spawn_shards(
         let watermark_gamma = args.watermark_gamma;
         let watermark_delta = args.watermark_delta;
         let cuda_memory_fraction = args.cuda_memory_fraction;
+        let rope_scaling = args.rope_scaling;
+        let rope_factor = args.rope_factor;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -802,6 +863,8 @@ fn spawn_shards(
                 watermark_gamma,
                 watermark_delta,
                 cuda_memory_fraction,
+                rope_scaling,
+                rope_factor,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -186,7 +186,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         )
 
         self.softmax_scale = self.head_size**-0.5
@@ -102,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
        )
 
         self.softmax_scale = self.head_size ** (-0.5)
@@ -133,7 +133,7 @@ class FlashRWAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            dim=self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -247,7 +247,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.head_size = hidden_size // num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -381,33 +381,65 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 
-    class PositionRotaryEmbedding(nn.Module):
-        def __init__(self, inv_freq):
-            super().__init__()
+    def _create_inv_freq(dim, base, device):
+        inv_freq = 1.0 / (
+            base
+            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+        )
+        return inv_freq
+
+    def _get_rope_config(config):
+        if os.getenv("ROPE_SCALING", None) is not None:
+            rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
+            return rope_scaling
+        return getattr(config, "rope_scaling", None)
+
+    class PositionRotaryEmbedding(nn.Module):
+        def __init__(self, inv_freq, scaling_factor):
+            super().__init__()
             self.inv_freq = inv_freq
             self._seq_len_cached = 0
             self._cos_cached = None
             self._sin_cached = None
             self._cos_k_cached = None
             self._sin_k_cached = None
+            self.scaling_factor = scaling_factor
+            self.dynamic_args = None
 
         @classmethod
-        def static(cls, dim, base, device):
-            inv_freq = 1.0 / (
-                base
-                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
-            )
-            return cls(inv_freq)
+        def static(cls, config, dim, base, device):
+            inv_freq = _create_inv_freq(dim, base, device)
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                if rope_scaling["type"] == "linear":
+                    pass
+                elif rope_scaling["type"] == "dynamic":
+                    return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)
 
         @classmethod
-        def load(cls, prefix, weights):
+        def load(cls, config, prefix, weights):
             # XXX: Always load this in float32 !
             dtype = weights.dtype
             weights.dtype = torch.float32
             inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
             weights.dtype = dtype
-            return cls(inv_freq)
+
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                if rope_scaling["type"] == "linear":
+                    pass
+                elif rope_scaling["type"] == "dynamic":
+                    return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)
 
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
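Aside (not part of the diff): a short usage sketch of the `_get_rope_config` helper added above, showing that the launcher-provided ROPE_SCALING / ROPE_FACTOR environment variables take precedence over the model config. The `SimpleNamespace` stand-in for `config` is only for illustration, and the snippet assumes `_get_rope_config` is in scope.

import os
from types import SimpleNamespace

# Stand-in for a transformers config object (illustration only).
config = SimpleNamespace(rope_scaling={"type": "linear", "factor": 4.0}, max_position_embeddings=2048)

# Without the env vars, the model config wins.
print(_get_rope_config(config))  # {'type': 'linear', 'factor': 4.0}

# With the env vars set by the launcher, they take precedence.
os.environ["ROPE_SCALING"] = "dynamic"
os.environ["ROPE_FACTOR"] = "2.0"
print(_get_rope_config(config))  # {'type': 'dynamic', 'factor': 2.0}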
@@ -419,8 +451,11 @@ try:
             ):
                 self._seq_len_cached = seqlen
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
                 # Don't do einsum, it converts fp32 to fp16
                 # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
                 freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)
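Aside (not part of the diff): a quick numerical check of the linear path added above, with illustrative values. Dividing `t` by the scaling factor means that, with a factor of 2.0, position 4096 produces exactly the angles an unscaled model produces at position 2048, which is how a model trained on 2048 positions is stretched over longer prompts.

import torch

dim, base = 128, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Unscaled angles at position 2048 vs. linearly scaled (factor 2.0) angles at 4096.
unscaled = torch.outer(torch.tensor([2048.0]), inv_freq)
scaled = torch.outer(torch.tensor([4096.0]) / 2.0, inv_freq)
assert torch.allclose(unscaled, scaled)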
@@ -446,5 +481,36 @@ try:
             rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
             return x
 
+    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
+            inv_freq = _create_inv_freq(dim, base, device)
+            super().__init__(inv_freq, scaling_factor)
+            self.dim = dim
+            self.max_position_embeddings = max_position_embeddings
+            self.base = base
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            # Reset the tables if the sequence length has changed,
+            # or if we're on a new device (possibly due to tracing for instance)
+            if (
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
+            ):
+                if seqlen > self.max_position_embeddings:
+                    newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
+                    self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
+                self._seq_len_cached = seqlen
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
+                # Don't do einsum, it converts fp32 to fp16
+                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+
+
 except ImportError:
     pass
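Aside (not part of the diff): a worked example of the dynamic ("NTK-aware") base adjustment in `DynamicPositionRotaryEmbedding` above, with illustrative numbers. Once `seqlen` exceeds `max_position_embeddings`, the base is grown so the rotary frequencies stretch over the longer sequence, and the inverse frequencies are rebuilt from the new base.

# Worked example of the dynamic base adjustment above (illustrative numbers only).
base, dim = 10000.0, 128
max_position_embeddings = 2048
scaling_factor = 2.0
seqlen = 4096  # twice the trained context

newbase = base * (
    (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
) ** (dim / (dim - 2))

print(round(newbase))  # ≈ 30528; inverse frequencies are then recomputed from this larger base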