fix: correctly index into mask when applying grammar

This commit is contained in:
drbh 2024-02-29 18:32:23 +00:00
parent 3dd7da2198
commit 141e67a1bf
3 changed files with 340 additions and 1 deletions

View File

@ -0,0 +1,274 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 30,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 5235,
"logprob": -10.0625,
"text": "info"
},
{
"id": 29901,
"logprob": -3.2324219,
"text": ":"
},
{
"id": 13260,
"logprob": -10.625,
"text": "dav"
},
{
"id": 333,
"logprob": -0.08276367,
"text": "id"
},
{
"id": 8753,
"logprob": -7.5273438,
"text": "hol"
},
{
"id": 17559,
"logprob": -3.8476562,
"text": "tz"
},
{
"id": 763,
"logprob": -10.140625,
"text": "like"
},
{
"id": 10697,
"logprob": -10.1953125,
"text": "trees"
},
{
"id": 322,
"logprob": -2.5742188,
"text": "and"
},
{
"id": 756,
"logprob": -7.4882812,
"text": "has"
},
{
"id": 1023,
"logprob": -5.0507812,
"text": "two"
},
{
"id": 274,
"logprob": -5.3164062,
"text": "c"
},
{
"id": 1446,
"logprob": -0.6694336,
"text": "ats"
},
{
"id": 29889,
"logprob": -0.9995117,
"text": "."
},
{
"id": 29871,
"logprob": -4.2421875,
"text": ""
}
],
"seed": null,
"tokens": [
{
"id": 6377,
"logprob": -0.14916992,
"special": false,
"text": "{\""
},
{
"id": 29888,
"logprob": -0.13598633,
"special": false,
"text": "f"
},
{
"id": 12935,
"logprob": -0.017669678,
"special": false,
"text": "irs"
},
{
"id": 29873,
"logprob": -0.00085639954,
"special": false,
"text": "t"
},
{
"id": 1170,
"logprob": -0.0054016113,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.13549805,
"special": false,
"text": "\":\""
},
{
"id": 19504,
"logprob": -0.8852539,
"special": false,
"text": "David"
},
{
"id": 3284,
"logprob": -0.16394043,
"special": false,
"text": "\",\""
},
{
"id": 29882,
"logprob": -0.08862305,
"special": false,
"text": "h"
},
{
"id": 711,
"logprob": -0.66259766,
"special": false,
"text": "ob"
},
{
"id": 1609,
"logprob": -5.51939e-05,
"special": false,
"text": "by"
},
{
"id": 4710,
"logprob": -0.23120117,
"special": false,
"text": "\":\""
},
{
"id": 29911,
"logprob": -2.3730469,
"special": false,
"text": "T"
},
{
"id": 11003,
"logprob": -0.032104492,
"special": false,
"text": "rees"
},
{
"id": 3284,
"logprob": -0.22021484,
"special": false,
"text": "\",\""
},
{
"id": 4230,
"logprob": -0.06726074,
"special": false,
"text": "last"
},
{
"id": 1170,
"logprob": -0.003501892,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.0045661926,
"special": false,
"text": "\":\""
},
{
"id": 29950,
"logprob": -0.12512207,
"special": false,
"text": "H"
},
{
"id": 14339,
"logprob": -0.009552002,
"special": false,
"text": "olt"
},
{
"id": 29920,
"logprob": -0.00042438507,
"special": false,
"text": "z"
},
{
"id": 3284,
"logprob": -0.11651611,
"special": false,
"text": "\",\""
},
{
"id": 29876,
"logprob": -0.29736328,
"special": false,
"text": "n"
},
{
"id": 398,
"logprob": -0.003030777,
"special": false,
"text": "um"
},
{
"id": 29907,
"logprob": -0.3774414,
"special": false,
"text": "C"
},
{
"id": 1446,
"logprob": -0.0003130436,
"special": false,
"text": "ats"
},
{
"id": 1115,
"logprob": -0.0021514893,
"special": false,
"text": "\":"
},
{
"id": 29906,
"logprob": -0.071899414,
"special": false,
"text": "2"
},
{
"id": 29913,
"logprob": -0.018997192,
"special": false,
"text": "}"
},
{
"id": 2,
"logprob": 0.0,
"special": true,
"text": "</s>"
}
],
"top_tokens": null
},
"generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}"
}

View File

@ -144,3 +144,68 @@ async def test_flash_llama_grammar_single_load_instance(
assert response.generated_text == "123456@gmail.com" assert response.generated_text == "123456@gmail.com"
assert response == response_snapshot assert response == response_snapshot
@pytest.fixture(scope="module")
def non_flash_llama_grammar_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
num_shard=1,
disable_grammar_support=False,
use_flash_attention=False,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
await non_flash_llama_grammar_handle.health(300)
return non_flash_llama_grammar_handle.client
@pytest.mark.asyncio
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):
response = await non_flash_llama_grammar.generate(
"info: david holtz like trees and has two cats. ",
max_new_tokens=100,
decoder_input_details=True,
seed=0,
grammar={
"type": GrammarType.Json,
"value": json.dumps(
{
"type": "object",
"$id": "https://example.com/person.schema.json",
"$schema": "https://json-schema.org/draft/2020-12/schema",
"title": "Person",
"properties": {
"firstName": {
"type": "string",
"description": "The person'''s first name.",
},
"lastName": {
"type": "string",
"description": "The person'''s last name.",
},
"hobby": {
"description": "The person'''s hobby.",
"type": "string",
},
"numCats": {
"description": "The number of cats the person has.",
"type": "integer",
"minimum": 0,
},
},
"required": ["firstName", "lastName", "hobby", "numCats"],
}
),
},
)
assert response.details.generated_tokens == 30
assert (
response.generated_text
== '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
)
assert response == response_snapshot

View File

@ -491,7 +491,7 @@ class GrammarLogitProcessor(LogitsProcessor):
return logits return logits
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
mask = torch.full_like(logits, -math.inf) mask = torch.full_like(logits, -math.inf)
mask[allowed_tokens] = 0 mask[:, allowed_tokens] = 0
biased_scores = logits + mask biased_scores = logits + mask
return biased_scores return biased_scores