diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 218e3f3a..ca1f6738 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -55,6 +55,10 @@ struct Args { otlp_endpoint: Option, #[clap(long, env)] cors_allow_origin: Vec, + #[clap(long, env)] + watermark_gamma: Option, + #[clap(long, env)] + watermark_delta: Option, } fn main() -> ExitCode { @@ -88,6 +92,8 @@ fn main() -> ExitCode { json_output, otlp_endpoint, cors_allow_origin, + watermark_gamma, + watermark_delta, } = args; // Signal handler @@ -243,6 +249,8 @@ fn main() -> ExitCode { huggingface_hub_cache, weights_cache_override, disable_custom_kernels, + watermark_gamma, + watermark_delta, otlp_endpoint, status_sender, shutdown, @@ -414,6 +422,8 @@ fn shard_manager( huggingface_hub_cache: Option, weights_cache_override: Option, disable_custom_kernels: bool, + watermark_gamma: Option, + watermark_delta: Option, otlp_endpoint: Option, status_sender: mpsc::Sender, shutdown: Arc>, @@ -494,6 +504,16 @@ fn shard_manager( env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into())) } + // Watermark Gamma + if let Some(watermark_gamma) = watermark_gamma { + env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into())) + } + + // Watermark Delta + if let Some(watermark_delta) = watermark_delta { + env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into())) + } + // Start process tracing::info!("Starting shard {rank}"); let mut p = match Popen::create( diff --git a/proto/generate.proto b/proto/generate.proto index 28a61362..dccd7e59 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -40,6 +40,8 @@ message NextTokenChooserParameters { uint64 seed = 5; /// repetition penalty float repetition_penalty = 6; + /// token watermarking using "A Watermark for Large Language Models" + bool watermark = 7; } message StoppingCriteriaParameters { diff --git a/router/src/lib.rs b/router/src/lib.rs index 78f9efd1..1386f6b5 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ 
-53,6 +53,9 @@ pub(crate) struct GenerateParameters { #[schema(inline, max_items = 4, example = json ! (["photographer"]))] pub stop: Vec, #[serde(default)] + #[schema(default = "false", example = true)] + pub watermark: bool, + #[serde(default)] #[schema(default = "true")] pub details: bool, #[serde(default)] @@ -72,7 +75,8 @@ fn default_parameters() -> GenerateParameters { do_sample: false, max_new_tokens: default_max_new_tokens(), return_full_text: None, - stop: vec![], + stop: Vec::new(), + watermark: false, details: false, seed: None, } diff --git a/router/src/queue.rs b/router/src/queue.rs index 8962aaec..088bdd3c 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -234,6 +234,7 @@ mod tests { do_sample: false, seed: 0, repetition_penalty: 0.0, + watermark: false }, stopping_parameters: StoppingCriteriaParameters { max_new_tokens: 0, diff --git a/router/src/server.rs b/router/src/server.rs index 5d4140ed..247c55b3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -72,6 +72,7 @@ async fn health(infer: Extension) -> Result<(), (StatusCode, Json "NextTokenChooser": return NextTokenChooser( + vocab_size=vocab_size, + watermark=pb.watermark, temperature=pb.temperature, repetition_penalty=pb.repetition_penalty, top_k=pb.top_k, diff --git a/server/text_generation/utils/watermark.py b/server/text_generation/utils/watermark.py new file mode 100644 index 00000000..6f5664fe --- /dev/null +++ b/server/text_generation/utils/watermark.py @@ -0,0 +1,87 @@ +# coding=utf-8 +# Copyright 2023 Authors of "A Watermark for Large Language Models" +# available at https://arxiv.org/abs/2301.10226 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
from transformers import LogitsProcessor

# Default greenlist fraction and logit bias, overridable via environment.
# NOTE: os.getenv returns a *string* whenever the variable is set (the
# launcher exports WATERMARK_GAMMA/WATERMARK_DELTA as env vars), so both
# values must be coerced to float or the arithmetic below raises TypeError.
GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))
DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))


class WatermarkLogitsProcessor(LogitsProcessor):
    """Soft watermarking of generated text, per "A Watermark for Large
    Language Models" (https://arxiv.org/abs/2301.10226).

    At each decoding step the previous token deterministically seeds an RNG
    that selects a pseudo-random "greenlist" of ``gamma * vocab_size`` token
    ids; their logits are boosted by ``delta`` so greenlist tokens are
    sampled more often, embedding a statistically detectable signal.
    """

    def __init__(
        self,
        vocab_size: int,
        gamma: float = GAMMA,
        delta: float = DELTA,
        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
        device: str = "cpu",
    ):
        # watermarking parameters
        self.vocab_size = vocab_size
        self.gamma = gamma
        self.delta = delta
        self.rng = torch.Generator(device=device)
        self.hash_key = hash_key

    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
        """Seed the RNG from the last token of the prefix, so the greenlist
        is a deterministic function of the preceding token."""
        assert (
            input_ids.shape[-1] >= 1
        ), "requires at least a 1 token prefix sequence to seed rng"
        prev_token = input_ids[-1].item()
        self.rng.manual_seed(self.hash_key * prev_token)

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        """Return the greenlist token ids for the current step.

        Returns a 1-D LongTensor (first ``gamma * vocab_size`` entries of a
        seeded permutation of the vocabulary), not a Python list.
        """
        # seed the rng using the previous tokens/prefix
        self._seed_rng(input_ids)

        greenlist_size = int(self.vocab_size * self.gamma)
        # NOTE(review): randperm runs on input_ids.device while self.rng was
        # created on `device`; torch requires these to match — confirm callers
        # pass a consistent device.
        vocab_permutation = torch.randperm(
            self.vocab_size, device=input_ids.device, generator=self.rng
        )
        greenlist_ids = vocab_permutation[:greenlist_size]
        return greenlist_ids

    @staticmethod
    def _calc_greenlist_mask(
        scores: torch.FloatTensor, greenlist_token_ids
    ) -> torch.BoolTensor:
        """Boolean mask over ``scores`` that is True at greenlist columns of
        the last row (the single active sequence)."""
        green_tokens_mask = torch.zeros_like(scores)
        green_tokens_mask[-1, greenlist_token_ids] = 1
        final_mask = green_tokens_mask.bool()
        return final_mask

    @staticmethod
    def _bias_greenlist_logits(
        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
    ) -> torch.Tensor:
        """Add ``greenlist_bias`` to every masked logit (mutates ``scores``)."""
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # Greenlist seeding uses a single prefix, so only batch size 1 is
        # supported here.
        assert len(input_ids) == 1
        greenlist_ids = self._get_greenlist_ids(input_ids[0])
        green_tokens_mask = self._calc_greenlist_mask(
            scores=scores, greenlist_token_ids=greenlist_ids
        )

        scores = self._bias_greenlist_logits(
            scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
        )
        return scores