mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: correctly index into mask when applying grammar
This commit is contained in:
parent
3dd7da2198
commit
141e67a1bf
@ -0,0 +1,274 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 30,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5235,
|
||||||
|
"logprob": -10.0625,
|
||||||
|
"text": "info"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -3.2324219,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13260,
|
||||||
|
"logprob": -10.625,
|
||||||
|
"text": "dav"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 333,
|
||||||
|
"logprob": -0.08276367,
|
||||||
|
"text": "id"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8753,
|
||||||
|
"logprob": -7.5273438,
|
||||||
|
"text": "hol"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 17559,
|
||||||
|
"logprob": -3.8476562,
|
||||||
|
"text": "tz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 763,
|
||||||
|
"logprob": -10.140625,
|
||||||
|
"text": "like"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10697,
|
||||||
|
"logprob": -10.1953125,
|
||||||
|
"text": "trees"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 322,
|
||||||
|
"logprob": -2.5742188,
|
||||||
|
"text": "and"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 756,
|
||||||
|
"logprob": -7.4882812,
|
||||||
|
"text": "has"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1023,
|
||||||
|
"logprob": -5.0507812,
|
||||||
|
"text": "two"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 274,
|
||||||
|
"logprob": -5.3164062,
|
||||||
|
"text": "c"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1446,
|
||||||
|
"logprob": -0.6694336,
|
||||||
|
"text": "ats"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -0.9995117,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29871,
|
||||||
|
"logprob": -4.2421875,
|
||||||
|
"text": ""
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 6377,
|
||||||
|
"logprob": -0.14916992,
|
||||||
|
"special": false,
|
||||||
|
"text": "{\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29888,
|
||||||
|
"logprob": -0.13598633,
|
||||||
|
"special": false,
|
||||||
|
"text": "f"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12935,
|
||||||
|
"logprob": -0.017669678,
|
||||||
|
"special": false,
|
||||||
|
"text": "irs"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29873,
|
||||||
|
"logprob": -0.00085639954,
|
||||||
|
"special": false,
|
||||||
|
"text": "t"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1170,
|
||||||
|
"logprob": -0.0054016113,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4710,
|
||||||
|
"logprob": -0.13549805,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 19504,
|
||||||
|
"logprob": -0.8852539,
|
||||||
|
"special": false,
|
||||||
|
"text": "David"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3284,
|
||||||
|
"logprob": -0.16394043,
|
||||||
|
"special": false,
|
||||||
|
"text": "\",\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29882,
|
||||||
|
"logprob": -0.08862305,
|
||||||
|
"special": false,
|
||||||
|
"text": "h"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 711,
|
||||||
|
"logprob": -0.66259766,
|
||||||
|
"special": false,
|
||||||
|
"text": "ob"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1609,
|
||||||
|
"logprob": -5.51939e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": "by"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4710,
|
||||||
|
"logprob": -0.23120117,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29911,
|
||||||
|
"logprob": -2.3730469,
|
||||||
|
"special": false,
|
||||||
|
"text": "T"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 11003,
|
||||||
|
"logprob": -0.032104492,
|
||||||
|
"special": false,
|
||||||
|
"text": "rees"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3284,
|
||||||
|
"logprob": -0.22021484,
|
||||||
|
"special": false,
|
||||||
|
"text": "\",\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4230,
|
||||||
|
"logprob": -0.06726074,
|
||||||
|
"special": false,
|
||||||
|
"text": "last"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1170,
|
||||||
|
"logprob": -0.003501892,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4710,
|
||||||
|
"logprob": -0.0045661926,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29950,
|
||||||
|
"logprob": -0.12512207,
|
||||||
|
"special": false,
|
||||||
|
"text": "H"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 14339,
|
||||||
|
"logprob": -0.009552002,
|
||||||
|
"special": false,
|
||||||
|
"text": "olt"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29920,
|
||||||
|
"logprob": -0.00042438507,
|
||||||
|
"special": false,
|
||||||
|
"text": "z"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3284,
|
||||||
|
"logprob": -0.11651611,
|
||||||
|
"special": false,
|
||||||
|
"text": "\",\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29876,
|
||||||
|
"logprob": -0.29736328,
|
||||||
|
"special": false,
|
||||||
|
"text": "n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 398,
|
||||||
|
"logprob": -0.003030777,
|
||||||
|
"special": false,
|
||||||
|
"text": "um"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29907,
|
||||||
|
"logprob": -0.3774414,
|
||||||
|
"special": false,
|
||||||
|
"text": "C"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1446,
|
||||||
|
"logprob": -0.0003130436,
|
||||||
|
"special": false,
|
||||||
|
"text": "ats"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1115,
|
||||||
|
"logprob": -0.0021514893,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29906,
|
||||||
|
"logprob": -0.071899414,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29913,
|
||||||
|
"logprob": -0.018997192,
|
||||||
|
"special": false,
|
||||||
|
"text": "}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": true,
|
||||||
|
"text": "</s>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}"
|
||||||
|
}
|
@ -144,3 +144,68 @@ async def test_flash_llama_grammar_single_load_instance(
|
|||||||
assert response.generated_text == "123456@gmail.com"
|
assert response.generated_text == "123456@gmail.com"
|
||||||
|
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def non_flash_llama_grammar_handle(launcher):
|
||||||
|
with launcher(
|
||||||
|
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
|
num_shard=1,
|
||||||
|
disable_grammar_support=False,
|
||||||
|
use_flash_attention=False,
|
||||||
|
) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
|
||||||
|
await non_flash_llama_grammar_handle.health(300)
|
||||||
|
return non_flash_llama_grammar_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):
|
||||||
|
response = await non_flash_llama_grammar.generate(
|
||||||
|
"info: david holtz like trees and has two cats. ",
|
||||||
|
max_new_tokens=100,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
grammar={
|
||||||
|
"type": GrammarType.Json,
|
||||||
|
"value": json.dumps(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"$id": "https://example.com/person.schema.json",
|
||||||
|
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||||
|
"title": "Person",
|
||||||
|
"properties": {
|
||||||
|
"firstName": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The person'''s first name.",
|
||||||
|
},
|
||||||
|
"lastName": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The person'''s last name.",
|
||||||
|
},
|
||||||
|
"hobby": {
|
||||||
|
"description": "The person'''s hobby.",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"numCats": {
|
||||||
|
"description": "The number of cats the person has.",
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["firstName", "lastName", "hobby", "numCats"],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 30
|
||||||
|
assert (
|
||||||
|
response.generated_text
|
||||||
|
== '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
|
||||||
|
)
|
||||||
|
assert response == response_snapshot
|
||||||
|
@ -491,7 +491,7 @@ class GrammarLogitProcessor(LogitsProcessor):
|
|||||||
return logits
|
return logits
|
||||||
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
|
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
|
||||||
mask = torch.full_like(logits, -math.inf)
|
mask = torch.full_like(logits, -math.inf)
|
||||||
mask[allowed_tokens] = 0
|
mask[:, allowed_tokens] = 0
|
||||||
biased_scores = logits + mask
|
biased_scores = logits + mask
|
||||||
return biased_scores
|
return biased_scores
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user