add watermarking

OlivierDehaene 2023-05-24 16:23:46 +02:00
parent b9ad3acc4e
commit c59fb353a0
12 changed files with 459 additions and 170 deletions

View File

@@ -34,65 +34,65 @@
     "tokens": [
       {
         "id": 408,
-        "logprob": -1.9267578,
+        "logprob": -0.07891846,
         "special": false,
         "text": " que"
       },
-      {
-        "id": 20288,
-        "logprob": -2.9257812,
-        "special": false,
-        "text": " l'on"
-      },
-      {
-        "id": 22255,
-        "logprob": -2.8964844,
-        "special": false,
-        "text": " trouve"
-      },
-      {
-        "id": 1622,
-        "logprob": -1.1083984,
-        "special": false,
-        "text": " une"
-      },
-      {
-        "id": 187079,
-        "logprob": -7.796875,
-        "special": false,
-        "text": " posture"
-      },
-      {
-        "id": 501,
-        "logprob": -5.390625,
-        "special": false,
-        "text": " par"
-      },
-      {
-        "id": 8741,
-        "logprob": -0.34936523,
-        "special": false,
-        "text": " rapport"
-      },
-      {
-        "id": 693,
-        "logprob": 0.0,
-        "special": false,
-        "text": " à"
-      },
       {
         "id": 366,
-        "logprob": -2.3378906,
+        "logprob": -1.2939453,
         "special": false,
         "text": " la"
       },
       {
-        "id": 36503,
-        "logprob": -3.6640625,
+        "id": 8769,
+        "logprob": -0.3708496,
         "special": false,
-        "text": " pratique"
+        "text": " personne"
+      },
+      {
+        "id": 1479,
+        "logprob": -2.2871094,
+        "special": false,
+        "text": " qui"
+      },
+      {
+        "id": 2997,
+        "logprob": -0.8671875,
+        "special": false,
+        "text": " vous"
+      },
+      {
+        "id": 35977,
+        "logprob": -1.5097656,
+        "special": false,
+        "text": " suit"
+      },
+      {
+        "id": 21558,
+        "logprob": -0.07891846,
+        "special": false,
+        "text": " ait"
+      },
+      {
+        "id": 447,
+        "logprob": -0.12695312,
+        "special": false,
+        "text": " un"
+      },
+      {
+        "id": 78606,
+        "logprob": -2.21875,
+        "special": false,
+        "text": " profil"
+      },
+      {
+        "id": 3899,
+        "logprob": -1.3535156,
+        "special": false,
+        "text": " bien"
       }
     ]
   },
-  "generated_text": "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique"
+  "generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui vous suit ait un profil bien"
 }

View File

@@ -1,8 +1,8 @@
 {
   "details": {
     "best_of_sequences": null,
-    "finish_reason": "length",
-    "generated_tokens": 10,
+    "finish_reason": "stop_sequence",
+    "generated_tokens": 5,
     "prefill": [
       {
         "id": 1,
@@ -24,65 +24,35 @@
     "tokens": [
       {
         "id": 5229,
-        "logprob": -3.3085938,
+        "logprob": -2.5683594,
         "special": false,
         "text": " failed"
       },
-      {
-        "id": 363,
-        "logprob": -3.984375,
-        "special": false,
-        "text": " for"
-      },
-      {
-        "id": 5641,
-        "logprob": -6.53125,
-        "special": false,
-        "text": " IP"
-      },
-      {
-        "id": 16428,
-        "logprob": -3.1835938,
-        "special": false,
-        "text": " Address"
-      },
       {
         "id": 29901,
-        "logprob": -1.2324219,
+        "logprob": -0.45336914,
         "special": false,
         "text": ":"
       },
       {
-        "id": 525,
-        "logprob": -2.6855469,
+        "id": 4829,
+        "logprob": -1.8408203,
         "special": false,
-        "text": " '"
+        "text": " Error"
       },
       {
-        "id": 8516,
-        "logprob": -7.1601562,
+        "id": 297,
+        "logprob": -1.0556641,
         "special": false,
-        "text": "None"
+        "text": " in"
       },
       {
-        "id": 4286,
-        "logprob": -2.4433594,
+        "id": 1243,
+        "logprob": 0.0,
         "special": false,
-        "text": "'."
-      },
-      {
-        "id": 13,
-        "logprob": -0.06530762,
-        "special": false,
-        "text": "\n"
-      },
-      {
-        "id": 294,
-        "logprob": -7.953125,
-        "special": false,
-        "text": "as"
+        "text": " test"
       }
     ]
   },
-  "generated_text": "Test requestfailed for IP Address: 'None'.\nas"
+  "generated_text": "Test requestfailed: Error in test"
 }

View File

@@ -1,8 +1,8 @@
 {
   "details": {
     "best_of_sequences": null,
-    "finish_reason": "eos_token",
-    "generated_tokens": 12,
+    "finish_reason": "length",
+    "generated_tokens": 60,
     "prefill": [
       {
         "id": 589,
@@ -29,77 +29,365 @@
     "tokens": [
       {
         "id": 2262,
-        "logprob": -0.7451172,
+        "logprob": -0.042999268,
         "special": false,
         "text": "():"
       },
       {
         "id": 284,
-        "logprob": -0.21325684,
+        "logprob": 0.0,
         "special": false,
         "text": "\n "
       },
-      {
-        "id": 5741,
-        "logprob": -5.734375,
-        "special": false,
-        "text": " logging"
-      },
       {
-        "id": 32,
+        "id": 1459,
         "logprob": 0.0,
         "special": false,
-        "text": "."
+        "text": " print"
       },
       {
-        "id": 1338,
-        "logprob": -0.3232422,
+        "id": 440,
+        "logprob": 0.0,
         "special": false,
-        "text": "info"
-      },
-      {
-        "id": 463,
-        "logprob": -1.0380859,
-        "special": false,
-        "text": "('"
+        "text": "(\""
       },
       {
         "id": 8279,
-        "logprob": -0.8378906,
+        "logprob": 0.0,
         "special": false,
         "text": "Hello"
       },
-      {
-        "id": 30,
-        "logprob": -1.9501953,
-        "special": false,
-        "text": ","
-      },
       {
         "id": 10896,
-        "logprob": -1.3476562,
+        "logprob": -0.3659668,
         "special": false,
         "text": " World"
       },
       {
-        "id": 683,
-        "logprob": -1.796875,
+        "id": 657,
+        "logprob": -0.49804688,
         "special": false,
-        "text": "')"
+        "text": "\")"
       },
       {
         "id": 203,
-        "logprob": -0.9873047,
+        "logprob": -0.11279297,
         "special": false,
         "text": "\n"
       },
       {
-        "id": 0,
-        "logprob": -0.7495117,
-        "special": true,
-        "text": "<|endoftext|>"
+        "id": 203,
+        "logprob": 0.0,
+        "special": false,
+        "text": "\n"
+      },
+      {
+        "id": 589,
+        "logprob": -0.20141602,
+        "special": false,
+        "text": "def"
+      },
+      {
+        "id": 1459,
+        "logprob": 0.0,
+        "special": false,
+        "text": " print"
+      },
+      {
+        "id": 81,
+        "logprob": 0.0,
+        "special": false,
+        "text": "_"
+      },
+      {
+        "id": 7656,
+        "logprob": 0.0,
+        "special": false,
+        "text": "hello"
+      },
+      {
+        "id": 81,
+        "logprob": 0.0,
+        "special": false,
+        "text": "_"
+      },
+      {
+        "id": 426,
+        "logprob": -0.051635742,
+        "special": false,
+        "text": "name"
+      },
+      {
+        "id": 26,
+        "logprob": 0.0,
+        "special": false,
+        "text": "("
+      },
+      {
+        "id": 426,
+        "logprob": 0.0,
+        "special": false,
+        "text": "name"
+      },
+      {
+        "id": 711,
+        "logprob": 0.0,
+        "special": false,
+        "text": "):"
+      },
+      {
+        "id": 284,
+        "logprob": 0.0,
+        "special": false,
+        "text": "\n "
+      },
+      {
+        "id": 1459,
+        "logprob": 0.0,
+        "special": false,
+        "text": " print"
+      },
+      {
+        "id": 440,
+        "logprob": -0.16027832,
+        "special": false,
+        "text": "(\""
+      },
+      {
+        "id": 8279,
+        "logprob": 0.0,
+        "special": false,
+        "text": "Hello"
+      },
+      {
+        "id": 313,
+        "logprob": 0.0,
+        "special": false,
+        "text": " \""
+      },
+      {
+        "id": 474,
+        "logprob": 0.0,
+        "special": false,
+        "text": " +"
+      },
+      {
+        "id": 636,
+        "logprob": 0.0,
+        "special": false,
+        "text": " name"
+      },
+      {
+        "id": 27,
+        "logprob": 0.0,
+        "special": false,
+        "text": ")"
+      },
+      {
+        "id": 203,
+        "logprob": 0.0,
+        "special": false,
+        "text": "\n"
+      },
+      {
+        "id": 203,
+        "logprob": 0.0,
+        "special": false,
+        "text": "\n"
+      },
+      {
+        "id": 589,
+        "logprob": 0.0,
+        "special": false,
+        "text": "def"
+      },
+      {
+        "id": 1459,
+        "logprob": 0.0,
+        "special": false,
+        "text": " print"
+      },
+      {
+        "id": 81,
+        "logprob": 0.0,
+        "special": false,
+        "text": "_"
+      },
+      {
+        "id": 7656,
+        "logprob": 0.0,
+        "special": false,
+        "text": "hello"
+      },
+      {
+        "id": 81,
+        "logprob": 0.0,
+        "special": false,
+        "text": "_"
+      },
+      {
+        "id": 426,
+        "logprob": 0.0,
+        "special": false,
+        "text": "name"
+      },
+      {
+        "id": 81,
+        "logprob": 0.0,
+        "special": false,
+        "text": "_"
+      },
+      {
+        "id": 381,
+        "logprob": 0.0,
+        "special": false,
+        "text": "age"
+      },
+      {
+        "id": 26,
+        "logprob": 0.0,
+        "special": false,
+        "text": "("
+      },
+      {
+        "id": 426,
+        "logprob": 0.0,
+        "special": false,
+        "text": "name"
+      },
+      {
+        "id": 30,
+        "logprob": 0.0,
+        "special": false,
+        "text": ","
+      },
+      {
+        "id": 11442,
+        "logprob": 0.0,
+        "special": false,
+        "text": " age"
+      },
+      {
+        "id": 711,
+        "logprob": 0.0,
+        "special": false,
+        "text": "):"
+      },
+      {
+        "id": 284,
+        "logprob": 0.0,
+        "special": false,
+        "text": "\n "
+      },
+      {
+        "id": 1459,
+        "logprob": 0.0,
+        "special": false,
+        "text": " print"
+      },
+      {
+        "id": 440,
+        "logprob": 0.0,
+        "special": false,
+        "text": "(\""
+      },
+      {
+        "id": 8279,
+        "logprob": 0.0,
+        "special": false,
+        "text": "Hello"
+      },
+      {
+        "id": 313,
+        "logprob": 0.0,
+        "special": false,
+        "text": " \""
+      },
+      {
+        "id": 474,
+        "logprob": 0.0,
+        "special": false,
+        "text": " +"
+      },
+      {
+        "id": 636,
+        "logprob": 0.0,
+        "special": false,
+        "text": " name"
+      },
+      {
+        "id": 474,
+        "logprob": 0.0,
+        "special": false,
+        "text": " +"
+      },
+      {
+        "id": 313,
+        "logprob": -0.6328125,
+        "special": false,
+        "text": " \""
+      },
+      {
+        "id": 313,
+        "logprob": -1.7011719,
+        "special": false,
+        "text": " \""
+      },
+      {
+        "id": 474,
+        "logprob": 0.0,
+        "special": false,
+        "text": " +"
+      },
+      {
+        "id": 596,
+        "logprob": 0.0,
+        "special": false,
+        "text": " str"
+      },
+      {
+        "id": 26,
+        "logprob": 0.0,
+        "special": false,
+        "text": "("
+      },
+      {
+        "id": 381,
+        "logprob": 0.0,
+        "special": false,
+        "text": "age"
+      },
+      {
+        "id": 490,
+        "logprob": 0.0,
+        "special": false,
+        "text": "))"
+      },
+      {
+        "id": 203,
+        "logprob": 0.0,
+        "special": false,
+        "text": "\n"
+      },
+      {
+        "id": 203,
+        "logprob": 0.0,
+        "special": false,
+        "text": "\n"
+      },
+      {
+        "id": 589,
+        "logprob": 0.0,
+        "special": false,
+        "text": "def"
+      },
+      {
+        "id": 1459,
+        "logprob": 0.0,
+        "special": false,
+        "text": " print"
       }
     ]
   },
-  "generated_text": "():\n logging.info('Hello, World')\n<|endoftext|>"
+  "generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print"
 }

View File

@@ -1,8 +1,8 @@
 {
   "details": {
     "best_of_sequences": null,
-    "finish_reason": "length",
-    "generated_tokens": 10,
+    "finish_reason": "eos_token",
+    "generated_tokens": 9,
     "prefill": [
       {
         "id": 0,
@@ -14,65 +14,59 @@
     "tokens": [
       {
         "id": 16017,
-        "logprob": -1.3505859,
+        "logprob": -0.30908203,
         "special": false,
         "text": " blue"
       },
       {
         "id": 20495,
-        "logprob": -0.50439453,
+        "logprob": 0.0,
         "special": false,
         "text": " sky"
       },
       {
         "id": 259,
-        "logprob": -1.2011719,
+        "logprob": -0.28271484,
         "special": false,
         "text": " "
       },
       {
         "id": 15484,
-        "logprob": -2.8378906,
+        "logprob": -1.7929688,
         "special": false,
         "text": "appear"
       },
       {
         "id": 345,
-        "logprob": -0.87597656,
+        "logprob": -0.8935547,
         "special": false,
         "text": "ed"
       },
       {
-        "id": 288,
-        "logprob": -1.8447266,
+        "id": 281,
+        "logprob": 0.0,
         "special": false,
-        "text": " to"
+        "text": " in"
       },
       {
-        "id": 35622,
-        "logprob": -7.1445312,
+        "id": 287,
+        "logprob": 0.0,
         "special": false,
-        "text": " cloud"
+        "text": " the"
       },
       {
-        "id": 263,
-        "logprob": -1.2929688,
+        "id": 20495,
+        "logprob": -0.32299805,
         "special": false,
-        "text": "s"
+        "text": " sky"
       },
       {
-        "id": 14701,
-        "logprob": -3.0761719,
-        "special": false,
-        "text": " above"
-      },
-      {
-        "id": 751,
-        "logprob": -4.4375,
-        "special": false,
-        "text": " all"
+        "id": 1,
+        "logprob": 0.0,
+        "special": true,
+        "text": "</s>"
       }
     ]
   },
-  "generated_text": "Why is the sky blue?blue sky appeared to clouds above all"
+  "generated_text": "Why is the sky blue?blue sky appeared in the sky"
 }

View File

@@ -40,7 +40,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
         seed=0,
     )
-    assert response.details.generated_tokens == 10
+    assert response.details.generated_tokens == 5
     assert response == response_snapshot

View File

@@ -29,7 +29,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
         "def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
     )
-    assert response.details.generated_tokens == 12
+    assert response.details.generated_tokens == 60
     assert response == response_snapshot

View File

@@ -43,7 +43,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
         seed=0,
     )
-    assert response.details.generated_tokens == 10
+    assert response.details.generated_tokens == 9
     assert response == response_snapshot

View File

@@ -631,7 +631,9 @@ class FlashCausalLM(Model):
                     ] = batch.input_ids[start_index + 1 : end_index]
                 else:
                     # Set prefill_tokens_indices to the correct slice
-                    prefill_tokens_indices = batch.input_ids
+                    prefill_tokens_indices = batch.input_ids[
+                        start_index + 1 : end_index
+                    ]
             batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
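Why the slice matters: prefill logprobs score each prompt position against the token that follows it, so the target ids for a single-sequence batch must be shifted by one, mirroring the multi-sequence branch above. A toy sketch of the off-by-one this hunk fixes (made-up values, not TGI code):

import torch

# Prefill logprobs compare the logit at position t with the token at t + 1,
# so the target slice must drop the first token of the sequence.
input_ids = torch.tensor([101, 7, 42, 9])  # prompt token ids for one sequence
start_index, end_index = 0, 4

buggy_targets = input_ids                               # misaligned: includes token 0
fixed_targets = input_ids[start_index + 1 : end_index]  # tensor([ 7, 42,  9])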

View File

@@ -231,7 +231,6 @@ class FlashSantacoderSharded(FlashSantacoder):
             device=device,
             rank=rank,
             world_size=world_size,
-            decode_buffer=1,
         )
     @staticmethod

View File

@ -1,14 +1,10 @@
import concurrent
import time
import datetime
import torch
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from datetime import timedelta
from loguru import logger
from pathlib import Path
from safetensors.torch import load_file, save_file
from safetensors.torch import save_file
from safetensors import safe_open
from typing import Dict, List

View File

@@ -2,7 +2,7 @@ import math
 import torch
 from functools import lru_cache
-from typing import Optional, List
+from typing import Optional, List, Dict, Union
 from transformers import (
     LogitsWarper,
@@ -45,9 +45,11 @@ class StaticWarper:
         self.cuda_graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.cuda_graph):
+            local_scores = self.static_scores
             for warper in self.warpers:
-                self.static_warped_scores = warper(None, self.static_scores)
+                local_scores = warper(None, local_scores)
+            self.static_warped_scores = local_scores
             # Compute logprobs
             self.static_next_logprob = torch.log_softmax(
                 self.static_warped_scores, -1
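The hunk above fixes how the captured CUDA graph composes warpers: the removed line fed the original static_scores to every warper and overwrote static_warped_scores on each iteration, so only the last warper had any effect. A stand-alone illustration of the two behaviours (toy callables, not TGI code):

import torch

double = lambda _ids, s: s * 2
add_one = lambda _ids, s: s + 1
warpers = [double, add_one]
scores = torch.tensor([1.0, 2.0])

# Old behaviour: every warper sees the original scores; only the last survives.
for warper in warpers:
    warped = warper(None, scores)
print(warped)        # tensor([2., 3.]) -- the doubling is lost

# Fixed behaviour: thread the running result through the chain.
local_scores = scores
for warper in warpers:
    local_scores = warper(None, local_scores)
print(local_scores)  # tensor([3., 5.])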
@@ -309,3 +311,32 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
     def filter(self, indices):
         self.mass = self.mass[indices]
         return self
+
+
+class HeterogeneousProcessorWrapper(LogitsProcessor):
+    r"""
+    A wrapper for logit warpers or processors without heterogeneous parameter support.
+    Args:
+        processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
+            A mapping of sample indices to logit warpers or processors, to be run sequentially.
+    """
+
+    def __init__(
+        self,
+        processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
+    ):
+        self.processors = processors
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        for i, processor in self.processors.items():
+            scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
+        return scores
+
+    def filter(self, indices):
+        new_processors = {}
+        for i, idx in enumerate(indices):
+            if idx in self.processors:
+                new_processors[i] = self.processors[idx]
+
+        self.processors = new_processors
+        return self
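A hypothetical usage sketch of the new wrapper: each per-sample processor runs only on its own row of the batch, and filter() remaps indices when requests leave the batch. BiasProcessor is an illustrative stand-in, not part of this commit:

import torch

class BiasProcessor:
    # Stand-in processor for illustration: adds a constant bias to the scores.
    def __init__(self, bias: float):
        self.bias = bias

    def __call__(self, input_ids, scores):
        return scores + self.bias

# Rows 0 and 2 get a processor; row 1 passes through untouched.
wrapper = HeterogeneousProcessorWrapper({0: BiasProcessor(1.0), 2: BiasProcessor(5.0)})
input_ids = torch.zeros((3, 4), dtype=torch.long)
scores = wrapper(input_ids, torch.zeros(3, 8))
# scores[0] == 1.0, scores[1] == 0.0, scores[2] == 5.0

# After request 1 completes, only rows [0, 2] survive; the processor that
# handled old index 2 is remapped to new index 1.
wrapper = wrapper.filter([0, 2])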

View File

@@ -18,6 +18,7 @@ from text_generation_server.utils.logits_process import (
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
+    HeterogeneousProcessorWrapper,
 )
@@ -168,7 +169,15 @@ class HeterogeneousNextTokenChooser:
         warpers = LogitsProcessorList()
         if any(watermark):
-            raise NotImplementedError("Watermarking not implemented")
+            warpers.append(
+                HeterogeneousProcessorWrapper(
+                    {
+                        i: WatermarkLogitsProcessor(device=device)
+                        for i, do_watermark in enumerate(watermark)
+                        if do_watermark
+                    }
+                )
+            )
         if any([x != 1.0 for x in repetition_penalty]):
             warpers.append(
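For context on what gets wrapped: WatermarkLogitsProcessor implements green-list watermarking in the style of Kirchenbauer et al. (2023), where a pseudo-random subset of the vocabulary, seeded from the previous token, receives a logit bonus that a detector can later test for statistically. A heavily simplified sketch of the idea; gamma, delta, and the single-token seeding rule are illustrative assumptions, not TGI's exact implementation:

import torch

class ToyWatermarkProcessor:
    def __init__(self, gamma: float = 0.5, delta: float = 2.0):
        self.gamma = gamma  # fraction of the vocabulary marked "green"
        self.delta = delta  # logit bonus added to green tokens
        self.rng = torch.Generator()

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        vocab_size = scores.shape[-1]
        # Seed from the previous token so a detector can recompute the green list.
        self.rng.manual_seed(int(input_ids[0, -1]))
        green_size = int(self.gamma * vocab_size)
        green = torch.randperm(vocab_size, generator=self.rng)[:green_size]
        scores[:, green] += self.delta  # bias sampling toward green tokens
        return scores

Because watermarking is opt-in per request, the HeterogeneousProcessorWrapper above applies the processor only to the batch rows whose watermark flag is set; all other rows are left untouched.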