mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: split flash and non flash grammar tests
This commit is contained in:
parent
141e67a1bf
commit
06fd5affa0
@ -1,4 +1,123 @@
|
|||||||
[
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1024,
|
||||||
|
"logprob": -10.578125,
|
||||||
|
"text": "name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -3.0332031,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13260,
|
||||||
|
"logprob": -9.171875,
|
||||||
|
"text": "dav"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 333,
|
||||||
|
"logprob": -0.04257202,
|
||||||
|
"text": "id"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -2.4785156,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4876,
|
||||||
|
"logprob": -10.7890625,
|
||||||
|
"text": "email"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -0.32495117,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 259,
|
||||||
|
"logprob": -9.4921875,
|
||||||
|
"text": " "
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -0.7709961,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29906,
|
||||||
|
"logprob": -0.33740234,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29941,
|
||||||
|
"logprob": -0.00995636,
|
||||||
|
"special": false,
|
||||||
|
"text": "3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29946,
|
||||||
|
"logprob": -0.64208984,
|
||||||
|
"special": false,
|
||||||
|
"text": "4"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29945,
|
||||||
|
"logprob": -0.4970703,
|
||||||
|
"special": false,
|
||||||
|
"text": "5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29953,
|
||||||
|
"logprob": -0.46533203,
|
||||||
|
"special": false,
|
||||||
|
"text": "6"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29992,
|
||||||
|
"logprob": -0.5336914,
|
||||||
|
"special": false,
|
||||||
|
"text": "@"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21980,
|
||||||
|
"logprob": -0.5361328,
|
||||||
|
"special": false,
|
||||||
|
"text": "gmail"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -0.00088739395,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 510,
|
||||||
|
"logprob": -0.0022735596,
|
||||||
|
"special": false,
|
||||||
|
"text": "com"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "123456@gmail.com"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
"best_of_sequences": null,
|
"best_of_sequences": null,
|
||||||
@ -355,124 +474,5 @@
|
|||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "123456@gmail.com"
|
"generated_text": "123456@gmail.com"
|
||||||
},
|
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 1,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "<s>"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1024,
|
|
||||||
"logprob": -10.578125,
|
|
||||||
"text": "name"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29901,
|
|
||||||
"logprob": -3.0332031,
|
|
||||||
"text": ":"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13260,
|
|
||||||
"logprob": -9.171875,
|
|
||||||
"text": "dav"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 333,
|
|
||||||
"logprob": -0.04257202,
|
|
||||||
"text": "id"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29889,
|
|
||||||
"logprob": -2.4785156,
|
|
||||||
"text": "."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 4876,
|
|
||||||
"logprob": -10.7890625,
|
|
||||||
"text": "email"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29901,
|
|
||||||
"logprob": -0.32495117,
|
|
||||||
"text": ":"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 259,
|
|
||||||
"logprob": -9.4921875,
|
|
||||||
"text": " "
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 29896,
|
|
||||||
"logprob": -0.7709961,
|
|
||||||
"special": false,
|
|
||||||
"text": "1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29906,
|
|
||||||
"logprob": -0.33740234,
|
|
||||||
"special": false,
|
|
||||||
"text": "2"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29941,
|
|
||||||
"logprob": -0.00995636,
|
|
||||||
"special": false,
|
|
||||||
"text": "3"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29946,
|
|
||||||
"logprob": -0.64208984,
|
|
||||||
"special": false,
|
|
||||||
"text": "4"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29945,
|
|
||||||
"logprob": -0.4970703,
|
|
||||||
"special": false,
|
|
||||||
"text": "5"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29953,
|
|
||||||
"logprob": -0.46533203,
|
|
||||||
"special": false,
|
|
||||||
"text": "6"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29992,
|
|
||||||
"logprob": -0.5336914,
|
|
||||||
"special": false,
|
|
||||||
"text": "@"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 21980,
|
|
||||||
"logprob": -0.5361328,
|
|
||||||
"special": false,
|
|
||||||
"text": "gmail"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29889,
|
|
||||||
"logprob": -0.00088739395,
|
|
||||||
"special": false,
|
|
||||||
"text": "."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 510,
|
|
||||||
"logprob": -0.0022735596,
|
|
||||||
"special": false,
|
|
||||||
"text": "com"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "123456@gmail.com"
|
|
||||||
}
|
}
|
||||||
]
|
]
|
@ -0,0 +1,274 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 30,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5235,
|
||||||
|
"logprob": -10.061389,
|
||||||
|
"text": "info"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -3.2349052,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13260,
|
||||||
|
"logprob": -10.626516,
|
||||||
|
"text": "dav"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 333,
|
||||||
|
"logprob": -0.08372568,
|
||||||
|
"text": "id"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8753,
|
||||||
|
"logprob": -7.5279083,
|
||||||
|
"text": "hol"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 17559,
|
||||||
|
"logprob": -3.8427715,
|
||||||
|
"text": "tz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 763,
|
||||||
|
"logprob": -10.143592,
|
||||||
|
"text": "like"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10697,
|
||||||
|
"logprob": -10.200588,
|
||||||
|
"text": "trees"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 322,
|
||||||
|
"logprob": -2.5744739,
|
||||||
|
"text": "and"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 756,
|
||||||
|
"logprob": -7.4822097,
|
||||||
|
"text": "has"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1023,
|
||||||
|
"logprob": -5.043413,
|
||||||
|
"text": "two"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 274,
|
||||||
|
"logprob": -5.326814,
|
||||||
|
"text": "c"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1446,
|
||||||
|
"logprob": -0.67299384,
|
||||||
|
"text": "ats"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -0.999048,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29871,
|
||||||
|
"logprob": -4.2404404,
|
||||||
|
"text": ""
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 6377,
|
||||||
|
"logprob": -0.1497998,
|
||||||
|
"special": false,
|
||||||
|
"text": "{\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29888,
|
||||||
|
"logprob": -0.1359236,
|
||||||
|
"special": false,
|
||||||
|
"text": "f"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12935,
|
||||||
|
"logprob": -0.01771052,
|
||||||
|
"special": false,
|
||||||
|
"text": "irs"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29873,
|
||||||
|
"logprob": -0.00084543246,
|
||||||
|
"special": false,
|
||||||
|
"text": "t"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1170,
|
||||||
|
"logprob": -0.0053624124,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4710,
|
||||||
|
"logprob": -0.13352497,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 19504,
|
||||||
|
"logprob": -0.8816582,
|
||||||
|
"special": false,
|
||||||
|
"text": "David"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3284,
|
||||||
|
"logprob": -0.1636697,
|
||||||
|
"special": false,
|
||||||
|
"text": "\",\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29882,
|
||||||
|
"logprob": -0.08828322,
|
||||||
|
"special": false,
|
||||||
|
"text": "h"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 711,
|
||||||
|
"logprob": -0.66238964,
|
||||||
|
"special": false,
|
||||||
|
"text": "ob"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1609,
|
||||||
|
"logprob": -5.566919e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": "by"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4710,
|
||||||
|
"logprob": -0.2296004,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29911,
|
||||||
|
"logprob": -2.3745353,
|
||||||
|
"special": false,
|
||||||
|
"text": "T"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 11003,
|
||||||
|
"logprob": -0.032119535,
|
||||||
|
"special": false,
|
||||||
|
"text": "rees"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3284,
|
||||||
|
"logprob": -0.22055298,
|
||||||
|
"special": false,
|
||||||
|
"text": "\",\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4230,
|
||||||
|
"logprob": -0.067228675,
|
||||||
|
"special": false,
|
||||||
|
"text": "last"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1170,
|
||||||
|
"logprob": -0.0035023084,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4710,
|
||||||
|
"logprob": -0.004494921,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29950,
|
||||||
|
"logprob": -0.12524654,
|
||||||
|
"special": false,
|
||||||
|
"text": "H"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 14339,
|
||||||
|
"logprob": -0.009601957,
|
||||||
|
"special": false,
|
||||||
|
"text": "olt"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29920,
|
||||||
|
"logprob": -0.00041619223,
|
||||||
|
"special": false,
|
||||||
|
"text": "z"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3284,
|
||||||
|
"logprob": -0.116980776,
|
||||||
|
"special": false,
|
||||||
|
"text": "\",\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29876,
|
||||||
|
"logprob": -0.2994127,
|
||||||
|
"special": false,
|
||||||
|
"text": "n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 398,
|
||||||
|
"logprob": -0.0030563807,
|
||||||
|
"special": false,
|
||||||
|
"text": "um"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29907,
|
||||||
|
"logprob": -0.37736154,
|
||||||
|
"special": false,
|
||||||
|
"text": "C"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1446,
|
||||||
|
"logprob": -0.00031073033,
|
||||||
|
"special": false,
|
||||||
|
"text": "ats"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1115,
|
||||||
|
"logprob": -0.0021851014,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29906,
|
||||||
|
"logprob": -0.07180126,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29913,
|
||||||
|
"logprob": -0.018707855,
|
||||||
|
"special": false,
|
||||||
|
"text": "}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": true,
|
||||||
|
"text": "</s>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}"
|
||||||
|
}
|
211
integration-tests/models/test_flash_grammar_llama.py
Normal file
211
integration-tests/models/test_flash_grammar_llama.py
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
|
||||||
|
from text_generation.types import GrammarType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_llama_grammar_handle(launcher):
|
||||||
|
with launcher(
|
||||||
|
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
|
||||||
|
) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_llama_grammar(flash_llama_grammar_handle):
|
||||||
|
await flash_llama_grammar_handle.health(300)
|
||||||
|
return flash_llama_grammar_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
|
||||||
|
response = await flash_llama_grammar.generate(
|
||||||
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
|
||||||
|
response = await flash_llama_grammar.generate(
|
||||||
|
"Whats Googles DNS",
|
||||||
|
max_new_tokens=10,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
grammar={
|
||||||
|
"type": GrammarType.Regex, # "regex"
|
||||||
|
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response.generated_text == "42.1.1.101"
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
||||||
|
response = await flash_llama_grammar.generate(
|
||||||
|
"info: david holtz like trees and has two cats. ",
|
||||||
|
max_new_tokens=100,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
grammar={
|
||||||
|
"type": GrammarType.Json, # "json"
|
||||||
|
"value": json.dumps(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"$id": "https://example.com/person.schema.json",
|
||||||
|
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||||
|
"title": "Person",
|
||||||
|
"properties": {
|
||||||
|
"firstName": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The person'''s first name.",
|
||||||
|
},
|
||||||
|
"lastName": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The person'''s last name.",
|
||||||
|
},
|
||||||
|
"hobby": {
|
||||||
|
"description": "The person'''s hobby.",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"numCats": {
|
||||||
|
"description": "The number of cats the person has.",
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["firstName", "lastName", "hobby", "numCats"],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 30
|
||||||
|
assert (
|
||||||
|
response.generated_text
|
||||||
|
== '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
|
||||||
|
)
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_llama_grammar_load(
|
||||||
|
flash_llama_grammar, generate_load, response_snapshot
|
||||||
|
):
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_llama_grammar,
|
||||||
|
"name: david. email: ",
|
||||||
|
max_new_tokens=10,
|
||||||
|
n=4,
|
||||||
|
stop_sequences=[".com"],
|
||||||
|
seed=0,
|
||||||
|
grammar={
|
||||||
|
"type": GrammarType.Regex, # "regex"
|
||||||
|
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
|
||||||
|
expected = "123456@gmail.com"
|
||||||
|
|
||||||
|
for response in responses:
|
||||||
|
assert response.generated_text == expected
|
||||||
|
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
# this is the same as the above test, but only fires off a single request
|
||||||
|
# this is only to ensure that the parallel and single inference produce the same result
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_llama_grammar_single_load_instance(
|
||||||
|
flash_llama_grammar, generate_load, response_snapshot
|
||||||
|
):
|
||||||
|
response = await flash_llama_grammar.generate(
|
||||||
|
"name: david. email: ",
|
||||||
|
max_new_tokens=10,
|
||||||
|
stop_sequences=[".com"],
|
||||||
|
seed=0,
|
||||||
|
grammar={
|
||||||
|
"type": GrammarType.Regex, # "regex"
|
||||||
|
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# assert response.details.generated_tokens == 30
|
||||||
|
assert response.generated_text == "123456@gmail.com"
|
||||||
|
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def non_flash_llama_grammar_handle(launcher):
|
||||||
|
with launcher(
|
||||||
|
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
|
num_shard=1,
|
||||||
|
disable_grammar_support=False,
|
||||||
|
use_flash_attention=False,
|
||||||
|
) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
|
||||||
|
await non_flash_llama_grammar_handle.health(300)
|
||||||
|
return non_flash_llama_grammar_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):
|
||||||
|
response = await non_flash_llama_grammar.generate(
|
||||||
|
"info: david holtz like trees and has two cats. ",
|
||||||
|
max_new_tokens=100,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
grammar={
|
||||||
|
"type": GrammarType.Json,
|
||||||
|
"value": json.dumps(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"$id": "https://example.com/person.schema.json",
|
||||||
|
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||||
|
"title": "Person",
|
||||||
|
"properties": {
|
||||||
|
"firstName": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The person'''s first name.",
|
||||||
|
},
|
||||||
|
"lastName": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The person'''s last name.",
|
||||||
|
},
|
||||||
|
"hobby": {
|
||||||
|
"description": "The person'''s hobby.",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"numCats": {
|
||||||
|
"description": "The number of cats the person has.",
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["firstName", "lastName", "hobby", "numCats"],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 30
|
||||||
|
assert (
|
||||||
|
response.generated_text
|
||||||
|
== '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
|
||||||
|
)
|
||||||
|
assert response == response_snapshot
|
@ -4,148 +4,6 @@ import json
|
|||||||
from text_generation.types import GrammarType
|
from text_generation.types import GrammarType
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def flash_llama_grammar_handle(launcher):
|
|
||||||
with launcher(
|
|
||||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
|
|
||||||
) as handle:
|
|
||||||
yield handle
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
async def flash_llama_grammar(flash_llama_grammar_handle):
|
|
||||||
await flash_llama_grammar_handle.health(300)
|
|
||||||
return flash_llama_grammar_handle.client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
|
|
||||||
response = await flash_llama_grammar.generate(
|
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
|
||||||
assert response == response_snapshot
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
|
|
||||||
response = await flash_llama_grammar.generate(
|
|
||||||
"Whats Googles DNS",
|
|
||||||
max_new_tokens=10,
|
|
||||||
decoder_input_details=True,
|
|
||||||
seed=0,
|
|
||||||
grammar={
|
|
||||||
"type": GrammarType.Regex, # "regex"
|
|
||||||
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
|
||||||
assert response.generated_text == "42.1.1.101"
|
|
||||||
assert response == response_snapshot
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
|
||||||
response = await flash_llama_grammar.generate(
|
|
||||||
"info: david holtz like trees and has two cats. ",
|
|
||||||
max_new_tokens=100,
|
|
||||||
decoder_input_details=True,
|
|
||||||
seed=0,
|
|
||||||
grammar={
|
|
||||||
"type": GrammarType.Json, # "json"
|
|
||||||
"value": json.dumps(
|
|
||||||
{
|
|
||||||
"type": "object",
|
|
||||||
"$id": "https://example.com/person.schema.json",
|
|
||||||
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
|
||||||
"title": "Person",
|
|
||||||
"properties": {
|
|
||||||
"firstName": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The person'''s first name.",
|
|
||||||
},
|
|
||||||
"lastName": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The person'''s last name.",
|
|
||||||
},
|
|
||||||
"hobby": {
|
|
||||||
"description": "The person'''s hobby.",
|
|
||||||
"type": "string",
|
|
||||||
},
|
|
||||||
"numCats": {
|
|
||||||
"description": "The number of cats the person has.",
|
|
||||||
"type": "integer",
|
|
||||||
"minimum": 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["firstName", "lastName", "hobby", "numCats"],
|
|
||||||
}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.details.generated_tokens == 30
|
|
||||||
assert (
|
|
||||||
response.generated_text
|
|
||||||
== '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
|
|
||||||
)
|
|
||||||
assert response == response_snapshot
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_flash_llama_grammar_load(
|
|
||||||
flash_llama_grammar, generate_load, response_snapshot
|
|
||||||
):
|
|
||||||
responses = await generate_load(
|
|
||||||
flash_llama_grammar,
|
|
||||||
"name: david. email: ",
|
|
||||||
max_new_tokens=10,
|
|
||||||
n=4,
|
|
||||||
stop_sequences=[".com"],
|
|
||||||
seed=0,
|
|
||||||
grammar={
|
|
||||||
"type": GrammarType.Regex, # "regex"
|
|
||||||
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(responses) == 4
|
|
||||||
|
|
||||||
expected = "123456@gmail.com"
|
|
||||||
|
|
||||||
for response in responses:
|
|
||||||
assert response.generated_text == expected
|
|
||||||
|
|
||||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
|
||||||
|
|
||||||
assert responses == response_snapshot
|
|
||||||
|
|
||||||
|
|
||||||
# this is the same as the above test, but only fires off a single request
|
|
||||||
# this is only to ensure that the parallel and single inference produce the same result
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_flash_llama_grammar_single_load_instance(
|
|
||||||
flash_llama_grammar, generate_load, response_snapshot
|
|
||||||
):
|
|
||||||
response = await flash_llama_grammar.generate(
|
|
||||||
"name: david. email: ",
|
|
||||||
max_new_tokens=10,
|
|
||||||
stop_sequences=[".com"],
|
|
||||||
seed=0,
|
|
||||||
grammar={
|
|
||||||
"type": GrammarType.Regex, # "regex"
|
|
||||||
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# assert response.details.generated_tokens == 30
|
|
||||||
assert response.generated_text == "123456@gmail.com"
|
|
||||||
|
|
||||||
assert response == response_snapshot
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def non_flash_llama_grammar_handle(launcher):
|
def non_flash_llama_grammar_handle(launcher):
|
||||||
with launcher(
|
with launcher(
|
||||||
|
Loading…
Reference in New Issue
Block a user