From f94fc831f46f4116a1d70f2aef4094fdbd39ffed Mon Sep 17 00:00:00 2001 From: drbh Date: Sat, 10 Feb 2024 03:11:59 +0000 Subject: [PATCH] feat: add grammar tests and typo tweaks --- clients/python/text_generation/client.py | 4 + clients/python/text_generation/types.py | 4 +- docs/source/basic_tutorials/launcher.md | 2 +- integration-tests/conftest.py | 11 +- .../test_flash_llama_grammar.json | 89 ++++ .../test_flash_llama_grammar_json.json | 274 ++++++++++ .../test_flash_llama_grammar_load.json | 478 ++++++++++++++++++ .../test_flash_llama_grammar_regex.json | 109 ++++ ...sh_llama_grammar_single_load_instance.json | 73 +++ .../models/test_grammar_llama.py | 132 +++++ router/src/validation.rs | 2 - 11 files changed, 1172 insertions(+), 6 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar.json create mode 100644 integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json create mode 100644 integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json create mode 100644 integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json create mode 100644 integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_single_load_instance.json create mode 100644 integration-tests/models/test_grammar_llama.py diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 0bf80f8c..06c29ce6 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -76,6 +76,7 @@ class Client: watermark: bool = False, decoder_input_details: bool = False, top_n_tokens: Optional[int] = None, + grammar: str = "" ) -> Response: """ Given a prompt, generate the following text @@ -138,6 +139,7 @@ class Client: watermark=watermark, decoder_input_details=decoder_input_details, top_n_tokens=top_n_tokens, + grammar=grammar ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -326,6 +328,7 @@ class AsyncClient: watermark: bool = False, decoder_input_details: bool = False, top_n_tokens: Optional[int] = None, + grammar: str = "" ) -> Response: """ Given a prompt, generate the following text asynchronously @@ -388,6 +391,7 @@ class AsyncClient: typical_p=typical_p, watermark=watermark, top_n_tokens=top_n_tokens, + grammar=grammar, ) request = Request(inputs=prompt, stream=False, parameters=parameters) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index aa02d8d8..3369a5fd 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -41,6 +41,8 @@ class Parameters(BaseModel): decoder_input_details: bool = False # Return the N most likely tokens at each step top_n_tokens: Optional[int] = None + # grammar to use for generation + grammar: Optional[str] = None @validator("best_of") def valid_best_of(cls, field_value, values): @@ -157,7 +159,7 @@ class Token(BaseModel): # Token text text: str # Logprob - logprob: float + logprob: Optional[float] = None # Is the token a special token # Can be used to ignore tokens when concatenating special: bool diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index bb79e1ef..be31a7a4 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -396,4 +396,4 @@ Options: -V, --version Print version -``` \ No newline at end of file +``` diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index efeda08d..86944c8e 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -370,11 +370,18 @@ def launcher(event_loop): @pytest.fixture(scope="module") def generate_load(): async def generate_load_inner( - client: AsyncClient, prompt: str, max_new_tokens: int, n: int + client: AsyncClient, + prompt: str, + max_new_tokens: int, + n: int, + **kwargs, ) -> List[Response]: futures = [ client.generate( - prompt, max_new_tokens=max_new_tokens, decoder_input_details=True + prompt, + max_new_tokens=max_new_tokens, + decoder_input_details=True, + **kwargs, ) for _ in range(n) ] diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar.json new file mode 100644 index 00000000..0e87f59e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -13.90625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -12.328125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.0566406, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.5253906, + "special": false, + "text": "\n" + }, + { + "id": 29902, + "logprob": -2.7578125, + "special": false, + "text": "I" + }, + { + "id": 4966, + "logprob": -1.9033203, + "special": false, + "text": " hope" + }, + { + "id": 445, + "logprob": -0.5019531, + "special": false, + "text": " this" + }, + { + "id": 6911, + "logprob": -0.21264648, + "special": false, + "text": " helps" + }, + { + "id": 29991, + "logprob": -0.5991211, + "special": false, + "text": "!" + }, + { + "id": 2803, + "logprob": -0.37475586, + "special": false, + "text": " Let" + }, + { + "id": 592, + "logprob": -0.018463135, + "special": false, + "text": " me" + }, + { + "id": 1073, + "logprob": -0.0008597374, + "special": false, + "text": " know" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nI hope this helps! Let me know" +} diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json new file mode 100644 index 00000000..d0e017c1 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json @@ -0,0 +1,274 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 30, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 5235, + "logprob": -10.0625, + "text": "info" + }, + { + "id": 29901, + "logprob": -3.2265625, + "text": ":" + }, + { + "id": 13260, + "logprob": -10.625, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.08276367, + "text": "id" + }, + { + "id": 8753, + "logprob": -7.5273438, + "text": "hol" + }, + { + "id": 17559, + "logprob": -3.8476562, + "text": "tz" + }, + { + "id": 763, + "logprob": -10.1484375, + "text": "like" + }, + { + "id": 10697, + "logprob": -10.1953125, + "text": "trees" + }, + { + "id": 322, + "logprob": -2.5683594, + "text": "and" + }, + { + "id": 756, + "logprob": -7.4882812, + "text": "has" + }, + { + "id": 1023, + "logprob": -5.0546875, + "text": "two" + }, + { + "id": 274, + "logprob": -5.3125, + "text": "c" + }, + { + "id": 1446, + "logprob": -0.6665039, + "text": "ats" + }, + { + "id": 29889, + "logprob": -1.0009766, + "text": "." + }, + { + "id": 29871, + "logprob": -4.2421875, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 6377, + "logprob": -0.15002441, + "special": false, + "text": "{\"" + }, + { + "id": 29888, + "logprob": -0.13549805, + "special": false, + "text": "f" + }, + { + "id": 12935, + "logprob": -0.017562866, + "special": false, + "text": "irs" + }, + { + "id": 29873, + "logprob": -0.0008444786, + "special": false, + "text": "t" + }, + { + "id": 1170, + "logprob": -0.0053634644, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.13537598, + "special": false, + "text": "\":\"" + }, + { + "id": 19504, + "logprob": -0.8886719, + "special": false, + "text": "David" + }, + { + "id": 3284, + "logprob": -0.16381836, + "special": false, + "text": "\",\"" + }, + { + "id": 4230, + "logprob": -0.02017212, + "special": false, + "text": "last" + }, + { + "id": 1170, + "logprob": -0.0013923645, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.0067749023, + "special": false, + "text": "\":\"" + }, + { + "id": 29950, + "logprob": -0.11407471, + "special": false, + "text": "H" + }, + { + "id": 14339, + "logprob": -0.0040626526, + "special": false, + "text": "olt" + }, + { + "id": 29920, + "logprob": -0.0032863617, + "special": false, + "text": "z" + }, + { + "id": 3284, + "logprob": -0.20507812, + "special": false, + "text": "\",\"" + }, + { + "id": 29882, + "logprob": -0.0068740845, + "special": false, + "text": "h" + }, + { + "id": 20838, + "logprob": -0.19714355, + "special": false, + "text": "obb" + }, + { + "id": 29891, + "logprob": -2.2649765e-06, + "special": false, + "text": "y" + }, + { + "id": 4710, + "logprob": -0.31860352, + "special": false, + "text": "\":\"" + }, + { + "id": 29911, + "logprob": -2.09375, + "special": false, + "text": "T" + }, + { + "id": 11003, + "logprob": -0.02053833, + "special": false, + "text": "rees" + }, + { + "id": 3284, + "logprob": -0.59814453, + "special": false, + "text": "\",\"" + }, + { + "id": 29876, + "logprob": -0.5732422, + "special": false, + "text": "n" + }, + { + "id": 398, + "logprob": -0.006198883, + "special": false, + "text": "um" + }, + { + "id": 29907, + "logprob": -0.45703125, + "special": false, + "text": "C" + }, + { + "id": 1446, + "logprob": -0.0002872944, + "special": false, + "text": "ats" + }, + { + "id": 1115, + "logprob": -0.002117157, + "special": false, + "text": "\":" + }, + { + "id": 29906, + "logprob": -0.089416504, + "special": false, + "text": "2" + }, + { + "id": 29913, + "logprob": -0.021835327, + "special": false, + "text": "}" + }, + { + "id": 2, + "logprob": 0.0, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}" +} diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json new file mode 100644 index 00000000..b7b26a2c --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json @@ -0,0 +1,478 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.03125, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04244995, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4863281, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32714844, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7685547, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.2376709, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.01008606, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64160156, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.5, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46557617, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5341797, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022907257, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.23840332, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.23840332, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.23840332, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + } +] diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json new file mode 100644 index 00000000..71dad72c --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json @@ -0,0 +1,109 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 806, + "logprob": -11.90625, + "text": "Wh" + }, + { + "id": 1446, + "logprob": -3.6660156, + "text": "ats" + }, + { + "id": 2921, + "logprob": -7.8203125, + "text": "Go" + }, + { + "id": 468, + "logprob": -8.0625, + "text": "og" + }, + { + "id": 793, + "logprob": -2.1816406, + "text": "les" + }, + { + "id": 16332, + "logprob": -9.71875, + "text": "DNS" + } + ], + "seed": null, + "tokens": [ + { + "id": 29946, + "logprob": -1.4736328, + "special": false, + "text": "4" + }, + { + "id": 29906, + "logprob": -0.91845703, + "special": false, + "text": "2" + }, + { + "id": 29889, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -1.1386719, + "special": false, + "text": "1" + }, + { + "id": 29889, + "logprob": -1.4638672, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -0.40771484, + "special": false, + "text": "1" + }, + { + "id": 29889, + "logprob": -0.17553711, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -0.20776367, + "special": false, + "text": "1" + }, + { + "id": 29900, + "logprob": -1.5546875, + "special": false, + "text": "0" + }, + { + "id": 29896, + "logprob": -1.3681641, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": "42.1.1.101" +} diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_single_load_instance.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_single_load_instance.json new file mode 100644 index 00000000..7ffb17cb --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_single_load_instance.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7685547, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.33666992, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.009979248, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.53759766, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.0008878708, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.002275467, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" +} diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py new file mode 100644 index 00000000..7492718b --- /dev/null +++ b/integration-tests/models/test_grammar_llama.py @@ -0,0 +1,132 @@ +import pytest +import json + + +@pytest.fixture(scope="module") +def flash_llama_grammar_handle(launcher): + with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_grammar(flash_llama_grammar_handle): + await flash_llama_grammar_handle.health(300) + return flash_llama_grammar_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "Whats Googles DNS", + max_new_tokens=10, + decoder_input_details=True, + seed=0, + grammar="((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "42.1.1.101" + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "info: david holtz like trees and has two cats. ", + max_new_tokens=100, + decoder_input_details=True, + seed=0, + grammar=json.dumps( + { + "$id": "https://example.com/person.schema.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Person", + "type": "object", + "properties": { + "firstName": { + "type": "string", + "description": "The person'''s first name.", + }, + "lastName": { + "type": "string", + "description": "The person'''s last name.", + }, + "hobby": {"description": "The person'''s hobby.", "type": "string"}, + "numCats": { + "description": "The number of cats the person has.", + "type": "integer", + "minimum": 0, + }, + }, + "required": ["firstName", "lastName", "hobby", "numCats"], + } + ), + ) + + assert response.details.generated_tokens == 30 + assert ( + response.generated_text + == '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}' + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_load( + flash_llama_grammar, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_grammar, + "name: david. email: ", + max_new_tokens=10, + n=4, + stop_sequences=[".com"], + seed=0, + grammar="[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + ) + + assert len(responses) == 4 + + expected = "123456@gmail.com" + + for response in responses: + assert response.generated_text == expected + + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot + + +# this is the same as the above test, but only fires off a single request +# this is only to ensure that the parallel and single inference produce the same result +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_single_load_instance( + flash_llama_grammar, generate_load, response_snapshot +): + response = await flash_llama_grammar.generate( + "name: david. email: ", + max_new_tokens=10, + stop_sequences=[".com"], + seed=0, + grammar="[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + ) + + # assert response.details.generated_tokens == 30 + assert response.generated_text == "123456@gmail.com" + + assert response == response_snapshot diff --git a/router/src/validation.rs b/router/src/validation.rs index 2959459d..4c38db68 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -293,8 +293,6 @@ impl Validation { .validate_input(request.inputs, truncate, max_new_tokens) .await?; - // initialize the grammar parameter - let grammar = grammar; // init the start state of the grammar let fsm_grammar_state = 0;