From 0be0506a7ebfa6072d5c3019e8fa0ef51e47a899 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 1 Mar 2023 17:14:44 +0100
Subject: [PATCH] add option to set the watermark gamma & delta from the
 launcher

---
 launcher/src/main.rs                      | 20 ++++++++++++++++++++
 server/text_generation/utils/watermark.py | 24 ++++++++++++++----------
 2 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 218e3f3a..ca1f6738 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -55,6 +55,10 @@ struct Args {
     otlp_endpoint: Option<String>,
     #[clap(long, env)]
     cors_allow_origin: Vec<String>,
+    #[clap(long, env)]
+    watermark_gamma: Option<f32>,
+    #[clap(long, env)]
+    watermark_delta: Option<f32>,
 }
 
 fn main() -> ExitCode {
@@ -88,6 +92,8 @@ fn main() -> ExitCode {
         json_output,
         otlp_endpoint,
         cors_allow_origin,
+        watermark_gamma,
+        watermark_delta,
     } = args;
 
     // Signal handler
@@ -243,6 +249,8 @@ fn main() -> ExitCode {
                 huggingface_hub_cache,
                 weights_cache_override,
                 disable_custom_kernels,
+                watermark_gamma,
+                watermark_delta,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -414,6 +422,8 @@ fn shard_manager(
     huggingface_hub_cache: Option<String>,
     weights_cache_override: Option<String>,
     disable_custom_kernels: bool,
+    watermark_gamma: Option<f32>,
+    watermark_delta: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<Mutex<bool>>,
@@ -494,6 +504,16 @@ fn shard_manager(
         env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
     }
 
+    // Watermark Gamma
+    if let Some(watermark_gamma) = watermark_gamma {
+        env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
+    }
+
+    // Watermark Delta
+    if let Some(watermark_delta) = watermark_delta {
+        env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
+    }
+
     // Start process
     tracing::info!("Starting shard {rank}");
     let mut p = match Popen::create(
diff --git a/server/text_generation/utils/watermark.py b/server/text_generation/utils/watermark.py
index abfe5f89..6f5664fe 100644
--- a/server/text_generation/utils/watermark.py
+++ b/server/text_generation/utils/watermark.py
@@ -13,19 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 
 import torch
 from transformers import LogitsProcessor
 
+GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))
+DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))
+
 
 class WatermarkLogitsProcessor(LogitsProcessor):
     def __init__(
-        self,
-        vocab_size: int,
-        gamma: float = 0.5,
-        delta: float = 2.0,
-        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
-        device: str = "cpu",
+            self,
+            vocab_size: int,
+            gamma: float = GAMMA,
+            delta: float = DELTA,
+            hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
+            device: str = "cpu",
     ):
         # watermarking parameters
         self.vocab_size = vocab_size
@@ -36,7 +40,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
 
     def _seed_rng(self, input_ids: torch.LongTensor) -> None:
         assert (
-            input_ids.shape[-1] >= 1
+                input_ids.shape[-1] >= 1
         ), "requires at least a 1 token prefix sequence to seed rng"
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
@@ -54,7 +58,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
 
     @staticmethod
     def _calc_greenlist_mask(
-        scores: torch.FloatTensor, greenlist_token_ids
+            scores: torch.FloatTensor, greenlist_token_ids
     ) -> torch.BoolTensor:
         green_tokens_mask = torch.zeros_like(scores)
         green_tokens_mask[-1, greenlist_token_ids] = 1
@@ -63,13 +67,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
 
     @staticmethod
     def _bias_greenlist_logits(
-        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
+            scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
     ) -> torch.Tensor:
         scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
         return scores
 
     def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
     ) -> torch.FloatTensor:
         assert len(input_ids) == 1
         greenlist_ids = self._get_greenlist_ids(input_ids[0])
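
Usage note: because the two new Args fields use clap's long and env attributes, they are
exposed both as launcher flags (--watermark-gamma, --watermark-delta) and as launcher
environment variables (WATERMARK_GAMMA, WATERMARK_DELTA). shard_manager then re-exports
whichever values were given into every shard's environment under the same names, which is
where the Python server reads them. An illustrative invocation, assuming the launcher
binary is named text-generation-launcher and using example values:

    text-generation-launcher --model-id bigscience/bloom-560m \
        --watermark-gamma 0.5 --watermark-delta 2.0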
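
On the Python side, the module-level GAMMA and DELTA constants become the defaults of the
WatermarkLogitsProcessor keyword arguments. A minimal sketch of that behavior, assuming the
server package is importable as text_generation and that the processor stores gamma and
delta as attributes (values are illustrative, not defaults from the patch):

    import os

    # The constants are read when the module is imported, so the variables
    # must be set before the import.
    os.environ["WATERMARK_GAMMA"] = "0.25"
    os.environ["WATERMARK_DELTA"] = "4.0"

    from text_generation.utils.watermark import WatermarkLogitsProcessor

    # gamma/delta fall back to the env-driven module constants when not passed.
    processor = WatermarkLogitsProcessor(vocab_size=32000)
    print(processor.gamma, processor.delta)  # expected: 0.25 4.0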