Adding Rope scaling.

This commit is contained in:
Nicolas Patry 2023-07-31 11:55:44 +00:00
parent 92bb56b0c1
commit edbba4ea36
5 changed files with 144 additions and 15 deletions

View File

@ -60,6 +60,26 @@ impl std::fmt::Display for Dtype {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling{
Linear,
Dynamic,
}
impl std::fmt::Display for RopeScaling {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// To keep in track with `server`.
match self {
RopeScaling::Linear => {
write!(f, "linear")
}
RopeScaling::Dynamic => {
write!(f, "dynamic")
}
}
}
}
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
@ -250,6 +270,27 @@ struct Args {
#[clap(default_value = "1.0", long, env)]
cuda_memory_fraction: f32,
/// Rope scaling will only be used for RoPE models
/// and allow rescaling the position rotary to accomodate for
/// larger prompts.
///
/// Goes together with `rope_factor`.
///
/// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
/// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
/// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed
/// basically)
///
/// `--rope-scaling linear --rope-factor` fully describes the scaling you want
#[clap(long, env)]
rope_scaling: Option<RopeScaling>,
/// Rope scaling will only be used for RoPE models
/// See `rope_scaling`
#[clap(long, env)]
rope_factor: Option<f32>,
/// Outputs the logs in JSON format (useful for telemetry)
#[clap(long, env)]
json_output: bool,
@ -305,6 +346,8 @@ fn shard_manager(
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
cuda_memory_fraction: f32,
rope_scaling: Option<RopeScaling>,
rope_factor: Option<f32>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>,
@ -358,6 +401,12 @@ fn shard_manager(
shard_args.push(revision)
}
let rope = match (rope_scaling, rope_factor) {
(None, None) => None,
(Some(scaling), None) => Some((scaling, 1.0)),
(Some(scaling), Some(factor)) => Some((scaling, factor)),
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
};
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
shard_args.push("--otlp-endpoint".to_string());
@ -395,6 +444,16 @@ fn shard_manager(
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
// Detect rope scaling
// Sending as env instead of CLI args to not bloat everything
// those only can be used by RoPE models, so passing information around
// for all models will complexify code unnecessarily
if let Some((scaling, factor)) = rope{
envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
}
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@ -784,6 +843,8 @@ fn spawn_shards(
let watermark_gamma = args.watermark_gamma;
let watermark_delta = args.watermark_delta;
let cuda_memory_fraction = args.cuda_memory_fraction;
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
thread::spawn(move || {
shard_manager(
model_id,
@ -802,6 +863,8 @@ fn spawn_shards(
watermark_gamma,
watermark_delta,
cuda_memory_fraction,
rope_scaling,
rope_factor,
otlp_endpoint,
status_sender,
shutdown,

View File

@ -186,7 +186,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
config=config, prefix=f"{prefix}.rotary_emb", weights=weights
)
self.softmax_scale = self.head_size**-0.5

View File

@ -102,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.num_heads = self.num_heads // weights.process_group.size()
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
config=config, prefix=f"{prefix}.rotary_emb", weights=weights
)
self.softmax_scale = self.head_size ** (-0.5)

View File

@ -133,7 +133,7 @@ class FlashRWAttention(torch.nn.Module):
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
config=config, dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)
@ -247,7 +247,7 @@ class FlashRWLargeAttention(torch.nn.Module):
self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, base=10000.0, device=weights.device
config=config, dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)

View File

@ -381,33 +381,65 @@ try:
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq):
super().__init__()
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
return rope_scaling
return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor):
super().__init__()
self.inv_freq = inv_freq
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
self.scaling_factor = scaling_factor
self.dynamic_args = None
@classmethod
def static(cls, dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return cls(inv_freq)
def static(cls, config, dim, base, device):
inv_freq = _create_inv_freq(dim, base, device)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
return cls(inv_freq, scaling_factor)
@classmethod
def load(cls, prefix, weights):
def load(cls, config, prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
return cls(inv_freq)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
return cls(inv_freq, scaling_factor)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@ -419,8 +451,11 @@ try:
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
if self.scaling_factor is not None:
t /= self.scaling_factor
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
@ -446,5 +481,36 @@ try:
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
return x
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
newbase = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
if self.scaling_factor is not None:
t /= self.scaling_factor
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
except ImportError:
pass