diff --git a/Cargo.lock b/Cargo.lock
index 671d615c..7962174b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3552,7 +3552,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "average",
  "clap",
@@ -3573,7 +3573,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -3590,7 +3590,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "clap",
  "ctrlc",
@@ -3608,7 +3608,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "async-stream",
  "axum",
diff --git a/router/src/config.rs b/router/src/config.rs
index 88cde69a..8640ede9 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -136,6 +136,7 @@ pub enum Config {
     Phi,
     #[serde(rename = "phi-msft")]
     PhiMsft,
+    Phi3,
     Llama,
     Baichuan,
     Gemma,
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 6e4a13cd..7d339fe5 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1029,10 +1029,10 @@ try:
             scaling_factor = None
             rope_scaling = _get_rope_config(config)
             if rope_scaling is not None:
-                scaling_factor = rope_scaling["factor"]
                 if rope_scaling["type"] == "linear":
                     pass
                 elif rope_scaling["type"] == "dynamic":
+                    scaling_factor = rope_scaling["factor"]
                     return DynamicPositionRotaryEmbedding(
                         dim=dim,
                         max_position_embeddings=config.max_position_embeddings,
@@ -1041,6 +1041,7 @@ try:
                         scaling_factor=scaling_factor,
                     )
                 elif rope_scaling["type"] == "yarn":
+                    scaling_factor = rope_scaling["factor"]
                     return YarnPositionRotaryEmbedding(
                         dim=2 * inv_freq.shape[0],
                         max_position_embeddings=rope_scaling[
@@ -1054,6 +1055,52 @@ try:
                         beta_fast=32,
                         beta_slow=1,
                     )
+                elif rope_scaling["type"] == "su":
+                    short_factor = torch.tensor(
+                        rope_scaling["short_factor"], dtype=torch.float32, device=device
+                    )
+                    short_inv_freq = 1.0 / (
+                        short_factor
+                        * base
+                        ** (
+                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                            / dim
+                        )
+                    )
+                    long_factor = torch.tensor(
+                        rope_scaling["long_factor"], dtype=torch.float32, device=device
+                    )
+                    long_inv_freq = 1.0 / (
+                        long_factor
+                        * base
+                        ** (
+                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                            / dim
+                        )
+                    )
+
+                    original_max_position_embeddings = (
+                        config.original_max_position_embeddings
+                    )
+                    max_position_embeddings = config.max_position_embeddings
+                    if max_position_embeddings <= original_max_position_embeddings:
+                        scaling_factor = 1.0
+                    else:
+                        scale = (
+                            max_position_embeddings / original_max_position_embeddings
+                        )
+                        scaling_factor = math.sqrt(
+                            1
+                            + math.log(scale)
+                            / math.log(original_max_position_embeddings)
+                        )
+
+                    return SuRotaryEmbedding(
+                        short_inv_freq=short_inv_freq,
+                        long_inv_freq=long_inv_freq,
+                        scaling_factor=scaling_factor,
+                        original_max_position_embeddings=original_max_position_embeddings,
+                    )
                 else:
                     raise NotImplementedError(
                         f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
@@ -1141,6 +1188,49 @@ try:
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
             return cos.unsqueeze(1), sin.unsqueeze(1)

+    class SuRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(
+            self,
+            short_inv_freq,
+            long_inv_freq,
+            scaling_factor,
+            original_max_position_embeddings,
+        ):
+            super(PositionRotaryEmbedding, self).__init__()
+            self.short_inv_freq = short_inv_freq
+            self.long_inv_freq = long_inv_freq
+            self.scaling_factor = scaling_factor
+            self.original_max_position_embeddings = original_max_position_embeddings
+            self._seq_len_cached = 0
+            self._cos_cached = None
+            self._sin_cached = None
+            self._cos_k_cached = None
+            self._sin_k_cached = None
+            self.dynamic_args = None
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            # Reset the tables if the sequence length has changed,
+            # or if we're on a new device (possibly due to tracing for instance)
+            if (
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
+            ):
+                self._seq_len_cached = seqlen
+                if seqlen > self.original_max_position_embeddings:
+                    inv_freq = self.long_inv_freq
+                else:
+                    inv_freq = self.short_inv_freq
+                t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
+                # Don't do einsum, it converts fp32 to fp16
+                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+                freqs = torch.outer(t, inv_freq.to(device=t.device))
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+
     class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
         def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
             inv_freq = _create_inv_freq(dim, base, device)
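Note, outside the patch itself: the "su" branch above derives a scaling factor of sqrt(1 + log(scale) / log(original_max_position_embeddings)) and leaves it at 1.0 whenever the configured context does not exceed the original training length. A minimal standalone sketch of that arithmetic, using assumed Phi-3-mini-128k-style values (original_max_position_embeddings = 4096, max_position_embeddings = 131072; both numbers are illustrative, not taken from the diff):

import math

# Assumed, illustrative config values (not from the patch).
original_max_position_embeddings = 4096
max_position_embeddings = 131072

if max_position_embeddings <= original_max_position_embeddings:
    scaling_factor = 1.0  # short context: no extra scaling
else:
    scale = max_position_embeddings / original_max_position_embeddings  # 32.0
    scaling_factor = math.sqrt(
        1 + math.log(scale) / math.log(original_max_position_embeddings)
    )

print(round(scaling_factor, 4))  # 1.1902 for these values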