From 141e67a1bf1528fa1d92be689b263a45bf7f6f19 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 29 Feb 2024 18:32:23 +0000 Subject: [PATCH] fix: correctly index into mask when applying grammar --- .../test_non_flash_llama_grammar_json.json | 274 ++++++++++++++++++ .../models/test_grammar_llama.py | 65 +++++ .../utils/logits_process.py | 2 +- 3 files changed, 340 insertions(+), 1 deletion(-) create mode 100644 integration-tests/models/__snapshots__/test_grammar_llama/test_non_flash_llama_grammar_json.json diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_non_flash_llama_grammar_json.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_non_flash_llama_grammar_json.json new file mode 100644 index 00000000..d7fb620d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_non_flash_llama_grammar_json.json @@ -0,0 +1,274 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 30, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 5235, + "logprob": -10.0625, + "text": "info" + }, + { + "id": 29901, + "logprob": -3.2324219, + "text": ":" + }, + { + "id": 13260, + "logprob": -10.625, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.08276367, + "text": "id" + }, + { + "id": 8753, + "logprob": -7.5273438, + "text": "hol" + }, + { + "id": 17559, + "logprob": -3.8476562, + "text": "tz" + }, + { + "id": 763, + "logprob": -10.140625, + "text": "like" + }, + { + "id": 10697, + "logprob": -10.1953125, + "text": "trees" + }, + { + "id": 322, + "logprob": -2.5742188, + "text": "and" + }, + { + "id": 756, + "logprob": -7.4882812, + "text": "has" + }, + { + "id": 1023, + "logprob": -5.0507812, + "text": "two" + }, + { + "id": 274, + "logprob": -5.3164062, + "text": "c" + }, + { + "id": 1446, + "logprob": -0.6694336, + "text": "ats" + }, + { + "id": 29889, + "logprob": -0.9995117, + "text": "." 
+ }, + { + "id": 29871, + "logprob": -4.2421875, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 6377, + "logprob": -0.14916992, + "special": false, + "text": "{\"" + }, + { + "id": 29888, + "logprob": -0.13598633, + "special": false, + "text": "f" + }, + { + "id": 12935, + "logprob": -0.017669678, + "special": false, + "text": "irs" + }, + { + "id": 29873, + "logprob": -0.00085639954, + "special": false, + "text": "t" + }, + { + "id": 1170, + "logprob": -0.0054016113, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.13549805, + "special": false, + "text": "\":\"" + }, + { + "id": 19504, + "logprob": -0.8852539, + "special": false, + "text": "David" + }, + { + "id": 3284, + "logprob": -0.16394043, + "special": false, + "text": "\",\"" + }, + { + "id": 29882, + "logprob": -0.08862305, + "special": false, + "text": "h" + }, + { + "id": 711, + "logprob": -0.66259766, + "special": false, + "text": "ob" + }, + { + "id": 1609, + "logprob": -5.51939e-05, + "special": false, + "text": "by" + }, + { + "id": 4710, + "logprob": -0.23120117, + "special": false, + "text": "\":\"" + }, + { + "id": 29911, + "logprob": -2.3730469, + "special": false, + "text": "T" + }, + { + "id": 11003, + "logprob": -0.032104492, + "special": false, + "text": "rees" + }, + { + "id": 3284, + "logprob": -0.22021484, + "special": false, + "text": "\",\"" + }, + { + "id": 4230, + "logprob": -0.06726074, + "special": false, + "text": "last" + }, + { + "id": 1170, + "logprob": -0.003501892, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.0045661926, + "special": false, + "text": "\":\"" + }, + { + "id": 29950, + "logprob": -0.12512207, + "special": false, + "text": "H" + }, + { + "id": 14339, + "logprob": -0.009552002, + "special": false, + "text": "olt" + }, + { + "id": 29920, + "logprob": -0.00042438507, + "special": false, + "text": "z" + }, + { + "id": 3284, + "logprob": -0.11651611, + "special": false, + "text": "\",\"" + }, + { + "id": 29876, + "logprob": -0.29736328, + "special": false, + "text": "n" + }, + { + "id": 398, + "logprob": -0.003030777, + "special": false, + "text": "um" + }, + { + "id": 29907, + "logprob": -0.3774414, + "special": false, + "text": "C" + }, + { + "id": 1446, + "logprob": -0.0003130436, + "special": false, + "text": "ats" + }, + { + "id": 1115, + "logprob": -0.0021514893, + "special": false, + "text": "\":" + }, + { + "id": 29906, + "logprob": -0.071899414, + "special": false, + "text": "2" + }, + { + "id": 29913, + "logprob": -0.018997192, + "special": false, + "text": "}" + }, + { + "id": 2, + "logprob": 0.0, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}" +} diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py index 585d0656..59e9774b 100644 --- a/integration-tests/models/test_grammar_llama.py +++ b/integration-tests/models/test_grammar_llama.py @@ -144,3 +144,68 @@ async def test_flash_llama_grammar_single_load_instance( assert response.generated_text == "123456@gmail.com" assert response == response_snapshot + + +@pytest.fixture(scope="module") +def non_flash_llama_grammar_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + num_shard=1, + disable_grammar_support=False, + use_flash_attention=False, + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def 
non_flash_llama_grammar(non_flash_llama_grammar_handle): + await non_flash_llama_grammar_handle.health(300) + return non_flash_llama_grammar_handle.client + + +@pytest.mark.asyncio +async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot): + response = await non_flash_llama_grammar.generate( + "info: david holtz like trees and has two cats. ", + max_new_tokens=100, + decoder_input_details=True, + seed=0, + grammar={ + "type": GrammarType.Json, + "value": json.dumps( + { + "type": "object", + "$id": "https://example.com/person.schema.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Person", + "properties": { + "firstName": { + "type": "string", + "description": "The person'''s first name.", + }, + "lastName": { + "type": "string", + "description": "The person'''s last name.", + }, + "hobby": { + "description": "The person'''s hobby.", + "type": "string", + }, + "numCats": { + "description": "The number of cats the person has.", + "type": "integer", + "minimum": 0, + }, + }, + "required": ["firstName", "lastName", "hobby", "numCats"], + } + ), + }, + ) + + assert response.details.generated_tokens == 30 + assert ( + response.generated_text + == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}' + ) + assert response == response_snapshot diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 40f31ce2..cd7efec8 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -491,7 +491,7 @@ class GrammarLogitProcessor(LogitsProcessor): return logits allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) mask = torch.full_like(logits, -math.inf) - mask[allowed_tokens] = 0 + mask[:, allowed_tokens] = 0 biased_scores = logits + mask return biased_scores
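
Note (not part of the patch): the one-line change in GrammarLogitProcessor matters because the logits reaching this processor on the non-flash path are a 2D tensor of shape (batch_size, vocab_size). Indexing the 2D mask with `mask[allowed_tokens]` treats the allowed token ids as row (batch) indices rather than vocabulary columns, so the wrong positions are zeroed, and the call can raise an IndexError as soon as an allowed token id falls outside the first (batch) dimension. `mask[:, allowed_tokens]` zeroes the allowed vocabulary columns in every row instead. A minimal sketch, with shapes and token ids that are illustrative assumptions rather than values taken from the patch:

    import math
    import torch

    # Assumed shapes for illustration: (batch_size, vocab_size) logits,
    # as seen on the non-flash code path.
    batch_size, vocab_size = 2, 8
    logits = torch.randn(batch_size, vocab_size)
    allowed_tokens = [1, 4, 6]  # token ids the grammar FSM currently permits

    mask = torch.full_like(logits, -math.inf)

    # Buggy indexing: `mask[allowed_tokens] = 0` would interpret the token ids
    # as batch (row) indices. Here it would even fail outright, since 4 and 6
    # exceed batch_size, and grammar-legal columns would stay at -inf.

    # Fixed indexing: zero the allowed *columns* for every row, so only
    # grammar-legal tokens keep a finite score after the addition below.
    mask[:, allowed_tokens] = 0
    biased_scores = logits + mask  # disallowed tokens end up at -inf

With the column indexing in place, sampling from `biased_scores` can only ever pick tokens the grammar allows, which is the behavior the new test_non_flash_llama_grammar_json test and its snapshot exercise end to end.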