diff --git a/proto/generate.proto b/proto/generate.proto index 28a61362..dccd7e59 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -40,6 +40,8 @@ message NextTokenChooserParameters { uint64 seed = 5; /// repetition penalty float repetition_penalty = 6; + /// token watermarking using "A Watermark for Large Language Models" + bool watermark = 7; } message StoppingCriteriaParameters { diff --git a/router/src/lib.rs b/router/src/lib.rs index 78f9efd1..1386f6b5 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -53,6 +53,9 @@ pub(crate) struct GenerateParameters { #[schema(inline, max_items = 4, example = json ! (["photographer"]))] pub stop: Vec, #[serde(default)] + #[schema(default = "false", example = true)] + pub watermark: bool, + #[serde(default)] #[schema(default = "true")] pub details: bool, #[serde(default)] @@ -72,7 +75,8 @@ fn default_parameters() -> GenerateParameters { do_sample: false, max_new_tokens: default_max_new_tokens(), return_full_text: None, - stop: vec![], + stop: Vec::new(), + watermark: false, details: false, seed: None, } diff --git a/router/src/server.rs b/router/src/server.rs index 75f84ba5..8c6aa93f 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -72,6 +72,7 @@ async fn health(infer: Extension) -> Result<(), (StatusCode, Json "NextTokenChooser": return NextTokenChooser( + vocab_size=vocab_size, + watermark=pb.watermark, temperature=pb.temperature, repetition_penalty=pb.repetition_penalty, top_k=pb.top_k, diff --git a/server/text_generation/utils/watermark.py b/server/text_generation/utils/watermark.py new file mode 100644 index 00000000..abfe5f89 --- /dev/null +++ b/server/text_generation/utils/watermark.py @@ -0,0 +1,83 @@ +# coding=utf-8 +# Copyright 2023 Authors of "A Watermark for Large Language Models" +# available at https://arxiv.org/abs/2301.10226 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import LogitsProcessor


class WatermarkLogitsProcessor(LogitsProcessor):
    """Bias next-token logits toward a pseudo-random "greenlist" of tokens.

    Implements the soft watermarking scheme of "A Watermark for Large
    Language Models" (https://arxiv.org/abs/2301.10226): the previous
    token deterministically seeds an RNG that selects a ``gamma`` fraction
    of the vocabulary (the greenlist), and ``delta`` is added to those
    tokens' logits before sampling.

    Only batch size 1 is supported (enforced in ``__call__``).
    """

    def __init__(
        self,
        vocab_size: int,
        gamma: float = 0.5,
        delta: float = 2.0,
        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
        device: str = "cpu",
    ):
        # watermarking parameters
        self.vocab_size = vocab_size
        self.gamma = gamma  # fraction of the vocabulary placed on the greenlist
        self.delta = delta  # logit bias added to greenlist tokens
        self.rng = torch.Generator(device=device)
        self.hash_key = hash_key

    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
        """Re-seed the RNG deterministically from the last prefix token."""
        assert (
            input_ids.shape[-1] >= 1
        ), "requires at least a 1 token prefix sequence to seed rng"
        prev_token = input_ids[-1].item()
        self.rng.manual_seed(self.hash_key * prev_token)

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        """Return the greenlist token ids for the current step.

        Note: returns a ``LongTensor`` of token ids; the original
        ``list[int]`` annotation was incorrect.
        """
        # seed the rng using the previous tokens/prefix
        self._seed_rng(input_ids)

        greenlist_size = int(self.vocab_size * self.gamma)
        # Sample the permutation on the generator's own device:
        # torch.randperm raises a RuntimeError when `device` and
        # `generator.device` disagree (e.g. CUDA input_ids with the default
        # CPU generator). Move the result to the input device afterwards —
        # a no-op when they already match.
        vocab_permutation = torch.randperm(
            self.vocab_size, device=self.rng.device, generator=self.rng
        )
        greenlist_ids = vocab_permutation[:greenlist_size].to(input_ids.device)
        return greenlist_ids

    @staticmethod
    def _calc_greenlist_mask(
        scores: torch.FloatTensor, greenlist_token_ids
    ) -> torch.BoolTensor:
        """Boolean mask over `scores` that is True at greenlist positions."""
        green_tokens_mask = torch.zeros_like(scores)
        # Only the last (single) row is marked; batch size 1 is assumed here.
        green_tokens_mask[-1, greenlist_token_ids] = 1
        return green_tokens_mask.bool()

    @staticmethod
    def _bias_greenlist_logits(
        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
    ) -> torch.Tensor:
        """Add `greenlist_bias` to every masked logit (mutates `scores`)."""
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # Watermarking is only defined here for unbatched generation.
        assert len(input_ids) == 1
        greenlist_ids = self._get_greenlist_ids(input_ids[0])
        green_tokens_mask = self._calc_greenlist_mask(
            scores=scores, greenlist_token_ids=greenlist_ids
        )
        return self._bias_greenlist_logits(
            scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
        )