mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
add watermarking
This commit is contained in:
parent
b9ad3acc4e
commit
c59fb353a0
@ -34,65 +34,65 @@
|
||||
"tokens": [
|
||||
{
|
||||
"id": 408,
|
||||
"logprob": -1.9267578,
|
||||
"logprob": -0.07891846,
|
||||
"special": false,
|
||||
"text": " que"
|
||||
},
|
||||
{
|
||||
"id": 20288,
|
||||
"logprob": -2.9257812,
|
||||
"special": false,
|
||||
"text": " l'on"
|
||||
},
|
||||
{
|
||||
"id": 22255,
|
||||
"logprob": -2.8964844,
|
||||
"special": false,
|
||||
"text": " trouve"
|
||||
},
|
||||
{
|
||||
"id": 1622,
|
||||
"logprob": -1.1083984,
|
||||
"special": false,
|
||||
"text": " une"
|
||||
},
|
||||
{
|
||||
"id": 187079,
|
||||
"logprob": -7.796875,
|
||||
"special": false,
|
||||
"text": " posture"
|
||||
},
|
||||
{
|
||||
"id": 501,
|
||||
"logprob": -5.390625,
|
||||
"special": false,
|
||||
"text": " par"
|
||||
},
|
||||
{
|
||||
"id": 8741,
|
||||
"logprob": -0.34936523,
|
||||
"special": false,
|
||||
"text": " rapport"
|
||||
},
|
||||
{
|
||||
"id": 693,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " à"
|
||||
},
|
||||
{
|
||||
"id": 366,
|
||||
"logprob": -2.3378906,
|
||||
"logprob": -1.2939453,
|
||||
"special": false,
|
||||
"text": " la"
|
||||
},
|
||||
{
|
||||
"id": 36503,
|
||||
"logprob": -3.6640625,
|
||||
"id": 8769,
|
||||
"logprob": -0.3708496,
|
||||
"special": false,
|
||||
"text": " pratique"
|
||||
"text": " personne"
|
||||
},
|
||||
{
|
||||
"id": 1479,
|
||||
"logprob": -2.2871094,
|
||||
"special": false,
|
||||
"text": " qui"
|
||||
},
|
||||
{
|
||||
"id": 2997,
|
||||
"logprob": -0.8671875,
|
||||
"special": false,
|
||||
"text": " vous"
|
||||
},
|
||||
{
|
||||
"id": 35977,
|
||||
"logprob": -1.5097656,
|
||||
"special": false,
|
||||
"text": " suit"
|
||||
},
|
||||
{
|
||||
"id": 21558,
|
||||
"logprob": -0.07891846,
|
||||
"special": false,
|
||||
"text": " ait"
|
||||
},
|
||||
{
|
||||
"id": 447,
|
||||
"logprob": -0.12695312,
|
||||
"special": false,
|
||||
"text": " un"
|
||||
},
|
||||
{
|
||||
"id": 78606,
|
||||
"logprob": -2.21875,
|
||||
"special": false,
|
||||
"text": " profil"
|
||||
},
|
||||
{
|
||||
"id": 3899,
|
||||
"logprob": -1.3535156,
|
||||
"special": false,
|
||||
"text": " bien"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique"
|
||||
"generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui vous suit ait un profil bien"
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"finish_reason": "stop_sequence",
|
||||
"generated_tokens": 5,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
@ -24,65 +24,35 @@
|
||||
"tokens": [
|
||||
{
|
||||
"id": 5229,
|
||||
"logprob": -3.3085938,
|
||||
"logprob": -2.5683594,
|
||||
"special": false,
|
||||
"text": " failed"
|
||||
},
|
||||
{
|
||||
"id": 363,
|
||||
"logprob": -3.984375,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 5641,
|
||||
"logprob": -6.53125,
|
||||
"special": false,
|
||||
"text": " IP"
|
||||
},
|
||||
{
|
||||
"id": 16428,
|
||||
"logprob": -3.1835938,
|
||||
"special": false,
|
||||
"text": " Address"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -1.2324219,
|
||||
"logprob": -0.45336914,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 525,
|
||||
"logprob": -2.6855469,
|
||||
"id": 4829,
|
||||
"logprob": -1.8408203,
|
||||
"special": false,
|
||||
"text": " '"
|
||||
"text": " Error"
|
||||
},
|
||||
{
|
||||
"id": 8516,
|
||||
"logprob": -7.1601562,
|
||||
"id": 297,
|
||||
"logprob": -1.0556641,
|
||||
"special": false,
|
||||
"text": "None"
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 4286,
|
||||
"logprob": -2.4433594,
|
||||
"id": 1243,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "'."
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.06530762,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 294,
|
||||
"logprob": -7.953125,
|
||||
"special": false,
|
||||
"text": "as"
|
||||
"text": " test"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "Test requestfailed for IP Address: 'None'.\nas"
|
||||
"generated_text": "Test requestfailed: Error in test"
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 12,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 60,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
@ -29,77 +29,365 @@
|
||||
"tokens": [
|
||||
{
|
||||
"id": 2262,
|
||||
"logprob": -0.7451172,
|
||||
"logprob": -0.042999268,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.21325684,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 5741,
|
||||
"logprob": -5.734375,
|
||||
"special": false,
|
||||
"text": " logging"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "."
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 1338,
|
||||
"logprob": -0.3232422,
|
||||
"id": 440,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "info"
|
||||
},
|
||||
{
|
||||
"id": 463,
|
||||
"logprob": -1.0380859,
|
||||
"special": false,
|
||||
"text": "('"
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8279,
|
||||
"logprob": -0.8378906,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.9501953,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 10896,
|
||||
"logprob": -1.3476562,
|
||||
"logprob": -0.3659668,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 683,
|
||||
"logprob": -1.796875,
|
||||
"id": 657,
|
||||
"logprob": -0.49804688,
|
||||
"special": false,
|
||||
"text": "')"
|
||||
"text": "\")"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": -0.9873047,
|
||||
"logprob": -0.11279297,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": -0.7495117,
|
||||
"special": true,
|
||||
"text": "<|endoftext|>"
|
||||
"id": 203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": -0.20141602,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7656,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "hello"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 426,
|
||||
"logprob": -0.051635742,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 426,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 711,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "):"
|
||||
},
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 440,
|
||||
"logprob": -0.16027832,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8279,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 313,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 474,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 636,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " name"
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ")"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7656,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "hello"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 426,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 381,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "age"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 426,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 11442,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " age"
|
||||
},
|
||||
{
|
||||
"id": 711,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "):"
|
||||
},
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 440,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8279,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 313,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 474,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 636,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " name"
|
||||
},
|
||||
{
|
||||
"id": 474,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 313,
|
||||
"logprob": -0.6328125,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 313,
|
||||
"logprob": -1.7011719,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 474,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 596,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " str"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 381,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "age"
|
||||
},
|
||||
{
|
||||
"id": 490,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "))"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "():\n logging.info('Hello, World')\n<|endoftext|>"
|
||||
"generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print"
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 9,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
@ -14,65 +14,59 @@
|
||||
"tokens": [
|
||||
{
|
||||
"id": 16017,
|
||||
"logprob": -1.3505859,
|
||||
"logprob": -0.30908203,
|
||||
"special": false,
|
||||
"text": " blue"
|
||||
},
|
||||
{
|
||||
"id": 20495,
|
||||
"logprob": -0.50439453,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " sky"
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"logprob": -1.2011719,
|
||||
"logprob": -0.28271484,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 15484,
|
||||
"logprob": -2.8378906,
|
||||
"logprob": -1.7929688,
|
||||
"special": false,
|
||||
"text": "appear"
|
||||
},
|
||||
{
|
||||
"id": 345,
|
||||
"logprob": -0.87597656,
|
||||
"logprob": -0.8935547,
|
||||
"special": false,
|
||||
"text": "ed"
|
||||
},
|
||||
{
|
||||
"id": 288,
|
||||
"logprob": -1.8447266,
|
||||
"id": 281,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 35622,
|
||||
"logprob": -7.1445312,
|
||||
"id": 287,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " cloud"
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -1.2929688,
|
||||
"id": 20495,
|
||||
"logprob": -0.32299805,
|
||||
"special": false,
|
||||
"text": "s"
|
||||
"text": " sky"
|
||||
},
|
||||
{
|
||||
"id": 14701,
|
||||
"logprob": -3.0761719,
|
||||
"special": false,
|
||||
"text": " above"
|
||||
},
|
||||
{
|
||||
"id": 751,
|
||||
"logprob": -4.4375,
|
||||
"special": false,
|
||||
"text": " all"
|
||||
"id": 1,
|
||||
"logprob": 0.0,
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "Why is the sky blue?blue sky appeared to clouds above all"
|
||||
"generated_text": "Why is the sky blue?blue sky appeared in the sky"
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response.details.generated_tokens == 5
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
@ -29,7 +29,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
|
||||
"def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 12
|
||||
assert response.details.generated_tokens == 60
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
@ -43,7 +43,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response.details.generated_tokens == 9
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
@ -631,7 +631,9 @@ class FlashCausalLM(Model):
|
||||
] = batch.input_ids[start_index + 1 : end_index]
|
||||
else:
|
||||
# Set prefill_tokens_indices to the correct slice
|
||||
prefill_tokens_indices = batch.input_ids
|
||||
prefill_tokens_indices = batch.input_ids[
|
||||
start_index + 1 : end_index
|
||||
]
|
||||
|
||||
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
|
||||
|
||||
|
@ -231,7 +231,6 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
decode_buffer=1,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -1,14 +1,10 @@
|
||||
import concurrent
|
||||
import time
|
||||
import datetime
|
||||
import torch
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import defaultdict
|
||||
from datetime import timedelta
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file, save_file
|
||||
from safetensors.torch import save_file
|
||||
from safetensors import safe_open
|
||||
from typing import Dict, List
|
||||
|
||||
|
@ -2,7 +2,7 @@ import math
|
||||
import torch
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Union
|
||||
|
||||
from transformers import (
|
||||
LogitsWarper,
|
||||
@ -45,9 +45,11 @@ class StaticWarper:
|
||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.graph(self.cuda_graph):
|
||||
local_scores = self.static_scores
|
||||
for warper in self.warpers:
|
||||
self.static_warped_scores = warper(None, self.static_scores)
|
||||
local_scores = warper(None, local_scores)
|
||||
|
||||
self.static_warped_scores = local_scores
|
||||
# Compute logprobs
|
||||
self.static_next_logprob = torch.log_softmax(
|
||||
self.static_warped_scores, -1
|
||||
@ -309,3 +311,32 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
||||
def filter(self, indices):
|
||||
self.mass = self.mass[indices]
|
||||
return self
|
||||
|
||||
|
||||
class HeterogeneousProcessorWrapper(LogitsProcessor):
|
||||
r"""
|
||||
A wrapper for logit warpers or processors without heterogeneous parameter support.
|
||||
Args:
|
||||
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
|
||||
A mapping of sample indices to logit warpers or processors, to be run sequentially.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
|
||||
):
|
||||
self.processors = processors
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
for i, processor in self.processors.items():
|
||||
scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
|
||||
return scores
|
||||
|
||||
def filter(self, indices):
|
||||
new_processors = {}
|
||||
for i, idx in enumerate(indices):
|
||||
if idx in self.processors:
|
||||
new_processors[i] = self.processors[idx]
|
||||
|
||||
self.processors = new_processors
|
||||
return self
|
||||
|
@ -18,6 +18,7 @@ from text_generation_server.utils.logits_process import (
|
||||
HeterogeneousTopKLogitsWarper,
|
||||
HeterogeneousTopPLogitsWarper,
|
||||
HeterogeneousTypicalLogitsWarper,
|
||||
HeterogeneousProcessorWrapper,
|
||||
)
|
||||
|
||||
|
||||
@ -168,7 +169,15 @@ class HeterogeneousNextTokenChooser:
|
||||
warpers = LogitsProcessorList()
|
||||
|
||||
if any(watermark):
|
||||
raise NotImplementedError("Watermarking not implemented")
|
||||
warpers.append(
|
||||
HeterogeneousProcessorWrapper(
|
||||
{
|
||||
i: WatermarkLogitsProcessor(device=device)
|
||||
for i, do_watermark in enumerate(watermark)
|
||||
if do_watermark
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if any([x != 1.0 for x in repetition_penalty]):
|
||||
warpers.append(
|
||||
|
Loading…
Reference in New Issue
Block a user