From c59fb353a0c3d39021cf8f04b7fdbef06c45c7e8 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 24 May 2023 16:23:46 +0200
Subject: [PATCH] add watermarking

---
 .../test_bloom_560m_all_params.json           |  96 ++---
 .../test_flash_llama_all_params.json          |  58 +--
 .../test_flash_starcoder_default_params.json  | 364 ++++++++++++++++--
 .../test_mt0_base_all_params.json             |  48 +--
 integration-tests/models/test_flash_llama.py  |   2 +-
 .../models/test_flash_starcoder.py            |   2 +-
 integration-tests/models/test_mt0_base.py     |   2 +-
 .../models/flash_causal_lm.py                 |   4 +-
 .../models/flash_santacoder.py                |   1 -
 .../text_generation_server/utils/convert.py   |   6 +-
 .../utils/logits_process.py                   |  35 +-
 server/text_generation_server/utils/tokens.py |  11 +-
 12 files changed, 459 insertions(+), 170 deletions(-)

diff --git a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json
index 93a95804..ace73416 100644
--- a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json
+++ b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json
@@ -34,65 +34,65 @@
 "tokens": [
 {
 "id": 408,
- "logprob": -1.9267578,
+ "logprob": -0.07891846,
 "special": false,
 "text": " que"
 },
- {
- "id": 20288,
- "logprob": -2.9257812,
- "special": false,
- "text": " l'on"
- },
- {
- "id": 22255,
- "logprob": -2.8964844,
- "special": false,
- "text": " trouve"
- },
- {
- "id": 1622,
- "logprob": -1.1083984,
- "special": false,
- "text": " une"
- },
- {
- "id": 187079,
- "logprob": -7.796875,
- "special": false,
- "text": " posture"
- },
- {
- "id": 501,
- "logprob": -5.390625,
- "special": false,
- "text": " par"
- },
- {
- "id": 8741,
- "logprob": -0.34936523,
- "special": false,
- "text": " rapport"
- },
- {
- "id": 693,
- "logprob": 0.0,
- "special": false,
- "text": " à"
- },
 {
 "id": 366,
- "logprob": -2.3378906,
+ "logprob": -1.2939453,
 "special": false,
 "text": " la"
 },
 {
- "id": 36503,
- "logprob": -3.6640625,
+ "id": 8769,
+ "logprob": -0.3708496,
 "special": false,
- "text": " pratique"
+ "text": " personne"
+ },
+ {
+ "id": 1479,
+ "logprob": -2.2871094,
+ "special": false,
+ "text": " qui"
+ },
+ {
+ "id": 2997,
+ "logprob": -0.8671875,
+ "special": false,
+ "text": " vous"
+ },
+ {
+ "id": 35977,
+ "logprob": -1.5097656,
+ "special": false,
+ "text": " suit"
+ },
+ {
+ "id": 21558,
+ "logprob": -0.07891846,
+ "special": false,
+ "text": " ait"
+ },
+ {
+ "id": 447,
+ "logprob": -0.12695312,
+ "special": false,
+ "text": " un"
+ },
+ {
+ "id": 78606,
+ "logprob": -2.21875,
+ "special": false,
+ "text": " profil"
+ },
+ {
+ "id": 3899,
+ "logprob": -1.3535156,
+ "special": false,
+ "text": " bien"
 }
 ]
 },
- "generated_text": "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique"
+ "generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui vous suit ait un profil bien"
 }
diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json
index 1b6b51a3..5be2870d 100644
--- a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json
+++ b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json
@@ -1,8 +1,8 @@
 {
 "details": {
 "best_of_sequences": null,
- "finish_reason": "length",
"finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "stop_sequence", + "generated_tokens": 5, "prefill": [ { "id": 1, @@ -24,65 +24,35 @@ "tokens": [ { "id": 5229, - "logprob": -3.3085938, + "logprob": -2.5683594, "special": false, "text": " failed" }, - { - "id": 363, - "logprob": -3.984375, - "special": false, - "text": " for" - }, - { - "id": 5641, - "logprob": -6.53125, - "special": false, - "text": " IP" - }, - { - "id": 16428, - "logprob": -3.1835938, - "special": false, - "text": " Address" - }, { "id": 29901, - "logprob": -1.2324219, + "logprob": -0.45336914, "special": false, "text": ":" }, { - "id": 525, - "logprob": -2.6855469, + "id": 4829, + "logprob": -1.8408203, "special": false, - "text": " '" + "text": " Error" }, { - "id": 8516, - "logprob": -7.1601562, + "id": 297, + "logprob": -1.0556641, "special": false, - "text": "None" + "text": " in" }, { - "id": 4286, - "logprob": -2.4433594, + "id": 1243, + "logprob": 0.0, "special": false, - "text": "'." - }, - { - "id": 13, - "logprob": -0.06530762, - "special": false, - "text": "\n" - }, - { - "id": 294, - "logprob": -7.953125, - "special": false, - "text": "as" + "text": " test" } ] }, - "generated_text": "Test requestfailed for IP Address: 'None'.\nas" + "generated_text": "Test requestfailed: Error in test" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json index 21bb509b..afd0b662 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "eos_token", - "generated_tokens": 12, + "finish_reason": "length", + "generated_tokens": 60, "prefill": [ { "id": 589, @@ -29,77 +29,365 @@ "tokens": [ { "id": 2262, - "logprob": -0.7451172, + "logprob": -0.042999268, "special": false, "text": "():" }, { "id": 284, - "logprob": -0.21325684, + "logprob": 0.0, "special": false, "text": "\n " }, { - "id": 5741, - "logprob": -5.734375, - "special": false, - "text": " logging" - }, - { - "id": 32, + "id": 1459, "logprob": 0.0, "special": false, - "text": "." 
+ "text": " print" }, { - "id": 1338, - "logprob": -0.3232422, + "id": 440, + "logprob": 0.0, "special": false, - "text": "info" - }, - { - "id": 463, - "logprob": -1.0380859, - "special": false, - "text": "('" + "text": "(\"" }, { "id": 8279, - "logprob": -0.8378906, + "logprob": 0.0, "special": false, "text": "Hello" }, - { - "id": 30, - "logprob": -1.9501953, - "special": false, - "text": "," - }, { "id": 10896, - "logprob": -1.3476562, + "logprob": -0.3659668, "special": false, "text": " World" }, { - "id": 683, - "logprob": -1.796875, + "id": 657, + "logprob": -0.49804688, "special": false, - "text": "')" + "text": "\")" }, { "id": 203, - "logprob": -0.9873047, + "logprob": -0.11279297, "special": false, "text": "\n" }, { - "id": 0, - "logprob": -0.7495117, - "special": true, - "text": "<|endoftext|>" + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -0.20141602, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7656, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 426, + "logprob": -0.051635742, + "special": false, + "text": "name" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 711, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.16027832, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 313, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 636, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 27, + "logprob": 0.0, + "special": false, + "text": ")" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7656, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 381, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 30, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 11442, + "logprob": 0.0, + "special": false, + "text": " age" + }, + { + "id": 711, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": 0.0, + "special": false, 
+ "text": "(\"" + }, + { + "id": 8279, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 313, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 636, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 313, + "logprob": -0.6328125, + "special": false, + "text": " \"" + }, + { + "id": 313, + "logprob": -1.7011719, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 596, + "logprob": 0.0, + "special": false, + "text": " str" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 381, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 490, + "logprob": 0.0, + "special": false, + "text": "))" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" } ] }, - "generated_text": "():\n logging.info('Hello, World')\n<|endoftext|>" + "generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print" } diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json index 3e9f3d73..024823d0 100644 --- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 9, "prefill": [ { "id": 0, @@ -14,65 +14,59 @@ "tokens": [ { "id": 16017, - "logprob": -1.3505859, + "logprob": -0.30908203, "special": false, "text": " blue" }, { "id": 20495, - "logprob": -0.50439453, + "logprob": 0.0, "special": false, "text": " sky" }, { "id": 259, - "logprob": -1.2011719, + "logprob": -0.28271484, "special": false, "text": " " }, { "id": 15484, - "logprob": -2.8378906, + "logprob": -1.7929688, "special": false, "text": "appear" }, { "id": 345, - "logprob": -0.87597656, + "logprob": -0.8935547, "special": false, "text": "ed" }, { - "id": 288, - "logprob": -1.8447266, + "id": 281, + "logprob": 0.0, "special": false, - "text": " to" + "text": " in" }, { - "id": 35622, - "logprob": -7.1445312, + "id": 287, + "logprob": 0.0, "special": false, - "text": " cloud" + "text": " the" }, { - "id": 263, - "logprob": -1.2929688, + "id": 20495, + "logprob": -0.32299805, "special": false, - "text": "s" + "text": " sky" }, { - "id": 14701, - "logprob": -3.0761719, - "special": false, - "text": " above" - }, - { - "id": 751, - "logprob": -4.4375, - "special": false, - "text": " all" + "id": 1, + "logprob": 0.0, + "special": true, + "text": "" } ] }, - "generated_text": "Why is the sky blue?blue sky appeared to clouds above all" + "generated_text": "Why is the sky blue?blue sky appeared in the sky" } diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py index 37468455..bf5b64ba 
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 9f837ced..482e0f54 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -231,7 +231,6 @@ class FlashSantacoderSharded(FlashSantacoder):
             device=device,
             rank=rank,
             world_size=world_size,
-            decode_buffer=1,
         )
 
     @staticmethod
diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py
index caf1a764..c43a4464 100644
--- a/server/text_generation_server/utils/convert.py
+++ b/server/text_generation_server/utils/convert.py
@@ -1,14 +1,10 @@
-import concurrent
-import time
 import datetime
 import torch
 
-from concurrent.futures import ThreadPoolExecutor
 from collections import defaultdict
-from datetime import timedelta
 from loguru import logger
 from pathlib import Path
-from safetensors.torch import load_file, save_file
+from safetensors.torch import save_file
 from safetensors import safe_open
 from typing import Dict, List
 
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 9a738190..b04e77b2 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -2,7 +2,7 @@ import math
 import torch
 
 from functools import lru_cache
-from typing import Optional, List
+from typing import Optional, List, Dict, Union
 
 from transformers import (
     LogitsWarper,
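
Note: in the StaticWarper hunk that follows, the old loop applied every warper to the original static_scores and overwrote static_warped_scores on each iteration, so with more than one warper only the last one actually took effect inside the CUDA graph; the new code threads each warper's output into the next. A sketch of the composed pattern with stock transformers warpers (illustrative only, run eagerly outside any CUDA graph; the shapes are made up):

    import torch
    from transformers import TopKLogitsWarper, TopPLogitsWarper

    warpers = [TopKLogitsWarper(top_k=50), TopPLogitsWarper(top_p=0.9)]
    scores = torch.randn(1, 32000)

    warped = scores
    for warper in warpers:
        warped = warper(None, warped)  # top-k first, then top-p on its output
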
@@ -45,9 +45,11 @@ class StaticWarper:
             self.cuda_graph = torch.cuda.CUDAGraph()
 
             with torch.cuda.graph(self.cuda_graph):
+                local_scores = self.static_scores
                 for warper in self.warpers:
-                    self.static_warped_scores = warper(None, self.static_scores)
+                    local_scores = warper(None, local_scores)
+                self.static_warped_scores = local_scores
 
                 # Compute logprobs
                 self.static_next_logprob = torch.log_softmax(
                     self.static_warped_scores, -1
@@ -309,3 +311,32 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
     def filter(self, indices):
         self.mass = self.mass[indices]
         return self
+
+
+class HeterogeneousProcessorWrapper(LogitsProcessor):
+    r"""
+    A wrapper for logit warpers or processors without heterogeneous parameter support.
+    Args:
+        processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
+            A mapping of sample indices to logit warpers or processors, to be run sequentially.
+    """
+
+    def __init__(
+        self,
+        processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
+    ):
+        self.processors = processors
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        for i, processor in self.processors.items():
+            scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
+        return scores
+
+    def filter(self, indices):
+        new_processors = {}
+        for i, idx in enumerate(indices):
+            if idx in self.processors:
+                new_processors[i] = self.processors[idx]
+
+        self.processors = new_processors
+        return self
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 118129ae..3c3bcb68 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -18,6 +18,7 @@ from text_generation_server.utils.logits_process import (
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
+    HeterogeneousProcessorWrapper,
 )
 
 
@@ -168,7 +169,15 @@ class HeterogeneousNextTokenChooser:
         warpers = LogitsProcessorList()
 
         if any(watermark):
-            raise NotImplementedError("Watermarking not implemented")
+            warpers.append(
+                HeterogeneousProcessorWrapper(
+                    {
+                        i: WatermarkLogitsProcessor(device=device)
+                        for i, do_watermark in enumerate(watermark)
+                        if do_watermark
+                    }
+                )
+            )
 
         if any([x != 1.0 for x in repetition_penalty]):
             warpers.append(
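
Note: the tokens.py hunk above is the core of the watermarking change: instead of raising NotImplementedError, it builds one WatermarkLogitsProcessor per request that asked for watermarking and registers them under their batch indices in a HeterogeneousProcessorWrapper, which applies each processor only to its own row of the scores. A sketch of that per-row dispatch (illustrative only; AddBias is a hypothetical stand-in for a processor without batched-parameter support):

    import torch

    class AddBias:
        """Stand-in for a LogitsProcessor that only handles one sequence."""
        def __init__(self, bias):
            self.bias = bias

        def __call__(self, input_ids, scores):
            return scores + self.bias

    processors = {0: AddBias(1.0), 2: AddBias(-1.0)}  # rows 1 and 3 untouched
    scores = torch.zeros(4, 8)
    for i, processor in processors.items():
        scores[i : i + 1] = processor(None, scores[i : i + 1])
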