mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: split flash and non flash grammar tests
This commit is contained in:
parent
141e67a1bf
commit
06fd5affa0
@ -1,4 +1,123 @@
|
||||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 1024,
|
||||
"logprob": -10.578125,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -3.0332031,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 13260,
|
||||
"logprob": -9.171875,
|
||||
"text": "dav"
|
||||
},
|
||||
{
|
||||
"id": 333,
|
||||
"logprob": -0.04257202,
|
||||
"text": "id"
|
||||
},
|
||||
{
|
||||
"id": 29889,
|
||||
"logprob": -2.4785156,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 4876,
|
||||
"logprob": -10.7890625,
|
||||
"text": "email"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -0.32495117,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"logprob": -9.4921875,
|
||||
"text": " "
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29896,
|
||||
"logprob": -0.7709961,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.33740234,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 29941,
|
||||
"logprob": -0.00995636,
|
||||
"special": false,
|
||||
"text": "3"
|
||||
},
|
||||
{
|
||||
"id": 29946,
|
||||
"logprob": -0.64208984,
|
||||
"special": false,
|
||||
"text": "4"
|
||||
},
|
||||
{
|
||||
"id": 29945,
|
||||
"logprob": -0.4970703,
|
||||
"special": false,
|
||||
"text": "5"
|
||||
},
|
||||
{
|
||||
"id": 29953,
|
||||
"logprob": -0.46533203,
|
||||
"special": false,
|
||||
"text": "6"
|
||||
},
|
||||
{
|
||||
"id": 29992,
|
||||
"logprob": -0.5336914,
|
||||
"special": false,
|
||||
"text": "@"
|
||||
},
|
||||
{
|
||||
"id": 21980,
|
||||
"logprob": -0.5361328,
|
||||
"special": false,
|
||||
"text": "gmail"
|
||||
},
|
||||
{
|
||||
"id": 29889,
|
||||
"logprob": -0.00088739395,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -0.0022735596,
|
||||
"special": false,
|
||||
"text": "com"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "123456@gmail.com"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
@ -355,124 +474,5 @@
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "123456@gmail.com"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 1024,
|
||||
"logprob": -10.578125,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -3.0332031,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 13260,
|
||||
"logprob": -9.171875,
|
||||
"text": "dav"
|
||||
},
|
||||
{
|
||||
"id": 333,
|
||||
"logprob": -0.04257202,
|
||||
"text": "id"
|
||||
},
|
||||
{
|
||||
"id": 29889,
|
||||
"logprob": -2.4785156,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 4876,
|
||||
"logprob": -10.7890625,
|
||||
"text": "email"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -0.32495117,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"logprob": -9.4921875,
|
||||
"text": " "
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29896,
|
||||
"logprob": -0.7709961,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.33740234,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 29941,
|
||||
"logprob": -0.00995636,
|
||||
"special": false,
|
||||
"text": "3"
|
||||
},
|
||||
{
|
||||
"id": 29946,
|
||||
"logprob": -0.64208984,
|
||||
"special": false,
|
||||
"text": "4"
|
||||
},
|
||||
{
|
||||
"id": 29945,
|
||||
"logprob": -0.4970703,
|
||||
"special": false,
|
||||
"text": "5"
|
||||
},
|
||||
{
|
||||
"id": 29953,
|
||||
"logprob": -0.46533203,
|
||||
"special": false,
|
||||
"text": "6"
|
||||
},
|
||||
{
|
||||
"id": 29992,
|
||||
"logprob": -0.5336914,
|
||||
"special": false,
|
||||
"text": "@"
|
||||
},
|
||||
{
|
||||
"id": 21980,
|
||||
"logprob": -0.5361328,
|
||||
"special": false,
|
||||
"text": "gmail"
|
||||
},
|
||||
{
|
||||
"id": 29889,
|
||||
"logprob": -0.00088739395,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -0.0022735596,
|
||||
"special": false,
|
||||
"text": "com"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "123456@gmail.com"
|
||||
}
|
||||
]
|
@ -0,0 +1,274 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 30,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 5235,
|
||||
"logprob": -10.061389,
|
||||
"text": "info"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -3.2349052,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 13260,
|
||||
"logprob": -10.626516,
|
||||
"text": "dav"
|
||||
},
|
||||
{
|
||||
"id": 333,
|
||||
"logprob": -0.08372568,
|
||||
"text": "id"
|
||||
},
|
||||
{
|
||||
"id": 8753,
|
||||
"logprob": -7.5279083,
|
||||
"text": "hol"
|
||||
},
|
||||
{
|
||||
"id": 17559,
|
||||
"logprob": -3.8427715,
|
||||
"text": "tz"
|
||||
},
|
||||
{
|
||||
"id": 763,
|
||||
"logprob": -10.143592,
|
||||
"text": "like"
|
||||
},
|
||||
{
|
||||
"id": 10697,
|
||||
"logprob": -10.200588,
|
||||
"text": "trees"
|
||||
},
|
||||
{
|
||||
"id": 322,
|
||||
"logprob": -2.5744739,
|
||||
"text": "and"
|
||||
},
|
||||
{
|
||||
"id": 756,
|
||||
"logprob": -7.4822097,
|
||||
"text": "has"
|
||||
},
|
||||
{
|
||||
"id": 1023,
|
||||
"logprob": -5.043413,
|
||||
"text": "two"
|
||||
},
|
||||
{
|
||||
"id": 274,
|
||||
"logprob": -5.326814,
|
||||
"text": "c"
|
||||
},
|
||||
{
|
||||
"id": 1446,
|
||||
"logprob": -0.67299384,
|
||||
"text": "ats"
|
||||
},
|
||||
{
|
||||
"id": 29889,
|
||||
"logprob": -0.999048,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 29871,
|
||||
"logprob": -4.2404404,
|
||||
"text": ""
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 6377,
|
||||
"logprob": -0.1497998,
|
||||
"special": false,
|
||||
"text": "{\""
|
||||
},
|
||||
{
|
||||
"id": 29888,
|
||||
"logprob": -0.1359236,
|
||||
"special": false,
|
||||
"text": "f"
|
||||
},
|
||||
{
|
||||
"id": 12935,
|
||||
"logprob": -0.01771052,
|
||||
"special": false,
|
||||
"text": "irs"
|
||||
},
|
||||
{
|
||||
"id": 29873,
|
||||
"logprob": -0.00084543246,
|
||||
"special": false,
|
||||
"text": "t"
|
||||
},
|
||||
{
|
||||
"id": 1170,
|
||||
"logprob": -0.0053624124,
|
||||
"special": false,
|
||||
"text": "Name"
|
||||
},
|
||||
{
|
||||
"id": 4710,
|
||||
"logprob": -0.13352497,
|
||||
"special": false,
|
||||
"text": "\":\""
|
||||
},
|
||||
{
|
||||
"id": 19504,
|
||||
"logprob": -0.8816582,
|
||||
"special": false,
|
||||
"text": "David"
|
||||
},
|
||||
{
|
||||
"id": 3284,
|
||||
"logprob": -0.1636697,
|
||||
"special": false,
|
||||
"text": "\",\""
|
||||
},
|
||||
{
|
||||
"id": 29882,
|
||||
"logprob": -0.08828322,
|
||||
"special": false,
|
||||
"text": "h"
|
||||
},
|
||||
{
|
||||
"id": 711,
|
||||
"logprob": -0.66238964,
|
||||
"special": false,
|
||||
"text": "ob"
|
||||
},
|
||||
{
|
||||
"id": 1609,
|
||||
"logprob": -5.566919e-05,
|
||||
"special": false,
|
||||
"text": "by"
|
||||
},
|
||||
{
|
||||
"id": 4710,
|
||||
"logprob": -0.2296004,
|
||||
"special": false,
|
||||
"text": "\":\""
|
||||
},
|
||||
{
|
||||
"id": 29911,
|
||||
"logprob": -2.3745353,
|
||||
"special": false,
|
||||
"text": "T"
|
||||
},
|
||||
{
|
||||
"id": 11003,
|
||||
"logprob": -0.032119535,
|
||||
"special": false,
|
||||
"text": "rees"
|
||||
},
|
||||
{
|
||||
"id": 3284,
|
||||
"logprob": -0.22055298,
|
||||
"special": false,
|
||||
"text": "\",\""
|
||||
},
|
||||
{
|
||||
"id": 4230,
|
||||
"logprob": -0.067228675,
|
||||
"special": false,
|
||||
"text": "last"
|
||||
},
|
||||
{
|
||||
"id": 1170,
|
||||
"logprob": -0.0035023084,
|
||||
"special": false,
|
||||
"text": "Name"
|
||||
},
|
||||
{
|
||||
"id": 4710,
|
||||
"logprob": -0.004494921,
|
||||
"special": false,
|
||||
"text": "\":\""
|
||||
},
|
||||
{
|
||||
"id": 29950,
|
||||
"logprob": -0.12524654,
|
||||
"special": false,
|
||||
"text": "H"
|
||||
},
|
||||
{
|
||||
"id": 14339,
|
||||
"logprob": -0.009601957,
|
||||
"special": false,
|
||||
"text": "olt"
|
||||
},
|
||||
{
|
||||
"id": 29920,
|
||||
"logprob": -0.00041619223,
|
||||
"special": false,
|
||||
"text": "z"
|
||||
},
|
||||
{
|
||||
"id": 3284,
|
||||
"logprob": -0.116980776,
|
||||
"special": false,
|
||||
"text": "\",\""
|
||||
},
|
||||
{
|
||||
"id": 29876,
|
||||
"logprob": -0.2994127,
|
||||
"special": false,
|
||||
"text": "n"
|
||||
},
|
||||
{
|
||||
"id": 398,
|
||||
"logprob": -0.0030563807,
|
||||
"special": false,
|
||||
"text": "um"
|
||||
},
|
||||
{
|
||||
"id": 29907,
|
||||
"logprob": -0.37736154,
|
||||
"special": false,
|
||||
"text": "C"
|
||||
},
|
||||
{
|
||||
"id": 1446,
|
||||
"logprob": -0.00031073033,
|
||||
"special": false,
|
||||
"text": "ats"
|
||||
},
|
||||
{
|
||||
"id": 1115,
|
||||
"logprob": -0.0021851014,
|
||||
"special": false,
|
||||
"text": "\":"
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.07180126,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 29913,
|
||||
"logprob": -0.018707855,
|
||||
"special": false,
|
||||
"text": "}"
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": 0.0,
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}"
|
||||
}
|
211
integration-tests/models/test_flash_grammar_llama.py
Normal file
211
integration-tests/models/test_flash_grammar_llama.py
Normal file
@ -0,0 +1,211 @@
|
||||
import pytest
|
||||
import json
|
||||
|
||||
from text_generation.types import GrammarType
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def flash_llama_grammar_handle(launcher):
    """Yield a launcher handle for TinyLlama with grammar support enabled.

    The model is started with ``num_shard=2`` and
    ``disable_grammar_support=False``; the handle stays valid for as long as
    the ``launcher(...)`` context is open (module scope).
    """
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    with launcher(model_id, num_shard=2, disable_grammar_support=False) as server:
        yield server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
async def flash_llama_grammar(flash_llama_grammar_handle):
    """Return the client of the sharded handle once its health check passes."""
    handle = flash_llama_grammar_handle
    # 300 is passed straight to health(); presumably a timeout in seconds — confirm.
    await handle.health(300)
    return handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
    """Generate without a grammar and check token count and snapshot."""
    result = await flash_llama_grammar.generate(
        "Test request",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert result.details.generated_tokens == 10
    assert result == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
    """Constrain generation with a dotted-quad regex grammar.

    Pins the number of generated tokens, the exact generated text, and the
    recorded snapshot for a seeded request.
    """
    ipv4_grammar = {
        "type": GrammarType.Regex,  # "regex"
        "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
    }

    result = await flash_llama_grammar.generate(
        "Whats Googles DNS",
        max_new_tokens=10,
        decoder_input_details=True,
        seed=0,
        grammar=ipv4_grammar,
    )

    assert result.details.generated_tokens == 10
    assert result.generated_text == "42.1.1.101"
    assert result == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
    """Constrain generation with a JSON-schema grammar and check the output.

    The schema is serialized with ``json.dumps`` and sent as the grammar
    value. The test pins the token count, the exact generated text, and the
    recorded snapshot for a seeded request.
    """
    # NOTE(review): the triple apostrophes in the descriptions look like a
    # quoting artifact; they are kept verbatim because the text is part of the
    # request payload.
    person_schema = {
        "type": "object",
        "$id": "https://example.com/person.schema.json",
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "title": "Person",
        "properties": {
            "firstName": {
                "type": "string",
                "description": "The person'''s first name.",
            },
            "lastName": {
                "type": "string",
                "description": "The person'''s last name.",
            },
            "hobby": {
                "description": "The person'''s hobby.",
                "type": "string",
            },
            "numCats": {
                "description": "The number of cats the person has.",
                "type": "integer",
                "minimum": 0,
            },
        },
        "required": ["firstName", "lastName", "hobby", "numCats"],
    }
    json_grammar = {
        "type": GrammarType.Json,  # "json"
        "value": json.dumps(person_schema),
    }

    result = await flash_llama_grammar.generate(
        "info: david holtz like trees and has two cats. ",
        max_new_tokens=100,
        decoder_input_details=True,
        seed=0,
        grammar=json_grammar,
    )

    expected_text = (
        '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
    )
    assert result.details.generated_tokens == 30
    assert result.generated_text == expected_text
    assert result == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_flash_llama_grammar_load(
    flash_llama_grammar, generate_load, response_snapshot
):
    """Send four seeded, regex-constrained requests via ``generate_load``.

    Every response must produce the same expected e-mail address, and the
    full list of responses must match the recorded snapshot.
    """
    email_grammar = {
        "type": GrammarType.Regex,  # "regex"
        "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+",  # email regex
    }

    responses = await generate_load(
        flash_llama_grammar,
        "name: david. email: ",
        max_new_tokens=10,
        n=4,
        stop_sequences=[".com"],
        seed=0,
        grammar=email_grammar,
    )

    assert len(responses) == 4

    texts = [item.generated_text for item in responses]
    for text in texts:
        assert text == "123456@gmail.com"

    # All responses must also agree with the first one.
    assert all([text == texts[0] for text in texts])

    assert responses == response_snapshot
|
||||
|
||||
|
||||
# Same request as the load test above, but issued once through the client.
# Its only purpose is to check that parallel and single inference agree.
@pytest.mark.asyncio
async def test_flash_llama_grammar_single_load_instance(
    flash_llama_grammar, generate_load, response_snapshot
):
    """Run the e-mail regex request once and compare text and snapshot.

    ``generate_load`` is requested as a fixture but not used in the body.
    """
    email_grammar = {
        "type": GrammarType.Regex,  # "regex"
        "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+",  # email regex
    }

    result = await flash_llama_grammar.generate(
        "name: david. email: ",
        max_new_tokens=10,
        stop_sequences=[".com"],
        seed=0,
        grammar=email_grammar,
    )

    # NOTE(review): a ``generated_tokens == 30`` assertion was commented out
    # here; the token count is currently not checked by this test.
    assert result.generated_text == "123456@gmail.com"

    assert result == response_snapshot
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def non_flash_llama_grammar_handle(launcher):
    """Yield a launcher handle for TinyLlama with flash attention turned off.

    The model is started with a single shard, grammar support enabled
    (``disable_grammar_support=False``) and ``use_flash_attention=False``.
    """
    launch_options = {
        "num_shard": 1,
        "disable_grammar_support": False,
        "use_flash_attention": False,
    }
    with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0", **launch_options) as server:
        yield server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
    """Return the client of the non-flash handle once its health check passes."""
    handle = non_flash_llama_grammar_handle
    # 300 is passed straight to health(); presumably a timeout in seconds — confirm.
    await handle.health(300)
    return handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):
    """Run the JSON-schema grammar request against the non-flash server.

    The schema is serialized with ``json.dumps`` and sent as the grammar
    value. The test pins the token count, the exact generated text, and the
    recorded snapshot for a seeded request.
    """
    # NOTE(review): the triple apostrophes in the descriptions look like a
    # quoting artifact; they are kept verbatim because the text is part of the
    # request payload.
    person_schema = {
        "type": "object",
        "$id": "https://example.com/person.schema.json",
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "title": "Person",
        "properties": {
            "firstName": {
                "type": "string",
                "description": "The person'''s first name.",
            },
            "lastName": {
                "type": "string",
                "description": "The person'''s last name.",
            },
            "hobby": {
                "description": "The person'''s hobby.",
                "type": "string",
            },
            "numCats": {
                "description": "The number of cats the person has.",
                "type": "integer",
                "minimum": 0,
            },
        },
        "required": ["firstName", "lastName", "hobby", "numCats"],
    }
    json_grammar = {
        "type": GrammarType.Json,
        "value": json.dumps(person_schema),
    }

    result = await non_flash_llama_grammar.generate(
        "info: david holtz like trees and has two cats. ",
        max_new_tokens=100,
        decoder_input_details=True,
        seed=0,
        grammar=json_grammar,
    )

    expected_text = (
        '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
    )
    assert result.details.generated_tokens == 30
    assert result.generated_text == expected_text
    assert result == response_snapshot
|
@ -4,148 +4,6 @@ import json
|
||||
from text_generation.types import GrammarType
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def flash_llama_grammar_handle(launcher):
    """Yield a launcher handle for TinyLlama with grammar support enabled.

    The model is started with ``num_shard=2`` and
    ``disable_grammar_support=False``; the handle stays valid for as long as
    the ``launcher(...)`` context is open (module scope).
    """
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    with launcher(model_id, num_shard=2, disable_grammar_support=False) as server:
        yield server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
async def flash_llama_grammar(flash_llama_grammar_handle):
    """Return the client of the sharded handle once its health check passes."""
    handle = flash_llama_grammar_handle
    # 300 is passed straight to health(); presumably a timeout in seconds — confirm.
    await handle.health(300)
    return handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
    """Generate without a grammar and check token count and snapshot."""
    result = await flash_llama_grammar.generate(
        "Test request",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert result.details.generated_tokens == 10
    assert result == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
    """Constrain generation with a dotted-quad regex grammar.

    Pins the number of generated tokens, the exact generated text, and the
    recorded snapshot for a seeded request.
    """
    ipv4_grammar = {
        "type": GrammarType.Regex,  # "regex"
        "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
    }

    result = await flash_llama_grammar.generate(
        "Whats Googles DNS",
        max_new_tokens=10,
        decoder_input_details=True,
        seed=0,
        grammar=ipv4_grammar,
    )

    assert result.details.generated_tokens == 10
    assert result.generated_text == "42.1.1.101"
    assert result == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
    """Constrain generation with a JSON-schema grammar and check the output.

    The schema is serialized with ``json.dumps`` and sent as the grammar
    value. The test pins the token count, the exact generated text, and the
    recorded snapshot for a seeded request.
    """
    # NOTE(review): the triple apostrophes in the descriptions look like a
    # quoting artifact; they are kept verbatim because the text is part of the
    # request payload.
    person_schema = {
        "type": "object",
        "$id": "https://example.com/person.schema.json",
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "title": "Person",
        "properties": {
            "firstName": {
                "type": "string",
                "description": "The person'''s first name.",
            },
            "lastName": {
                "type": "string",
                "description": "The person'''s last name.",
            },
            "hobby": {
                "description": "The person'''s hobby.",
                "type": "string",
            },
            "numCats": {
                "description": "The number of cats the person has.",
                "type": "integer",
                "minimum": 0,
            },
        },
        "required": ["firstName", "lastName", "hobby", "numCats"],
    }
    json_grammar = {
        "type": GrammarType.Json,  # "json"
        "value": json.dumps(person_schema),
    }

    result = await flash_llama_grammar.generate(
        "info: david holtz like trees and has two cats. ",
        max_new_tokens=100,
        decoder_input_details=True,
        seed=0,
        grammar=json_grammar,
    )

    expected_text = (
        '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
    )
    assert result.details.generated_tokens == 30
    assert result.generated_text == expected_text
    assert result == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_flash_llama_grammar_load(
    flash_llama_grammar, generate_load, response_snapshot
):
    """Send four seeded, regex-constrained requests via ``generate_load``.

    Every response must produce the same expected e-mail address, and the
    full list of responses must match the recorded snapshot.
    """
    email_grammar = {
        "type": GrammarType.Regex,  # "regex"
        "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+",  # email regex
    }

    responses = await generate_load(
        flash_llama_grammar,
        "name: david. email: ",
        max_new_tokens=10,
        n=4,
        stop_sequences=[".com"],
        seed=0,
        grammar=email_grammar,
    )

    assert len(responses) == 4

    texts = [item.generated_text for item in responses]
    for text in texts:
        assert text == "123456@gmail.com"

    # All responses must also agree with the first one.
    assert all([text == texts[0] for text in texts])

    assert responses == response_snapshot
|
||||
|
||||
|
||||
# Same request as the load test above, but issued once through the client.
# Its only purpose is to check that parallel and single inference agree.
@pytest.mark.asyncio
async def test_flash_llama_grammar_single_load_instance(
    flash_llama_grammar, generate_load, response_snapshot
):
    """Run the e-mail regex request once and compare text and snapshot.

    ``generate_load`` is requested as a fixture but not used in the body.
    """
    email_grammar = {
        "type": GrammarType.Regex,  # "regex"
        "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+",  # email regex
    }

    result = await flash_llama_grammar.generate(
        "name: david. email: ",
        max_new_tokens=10,
        stop_sequences=[".com"],
        seed=0,
        grammar=email_grammar,
    )

    # NOTE(review): a ``generated_tokens == 30`` assertion was commented out
    # here; the token count is currently not checked by this test.
    assert result.generated_text == "123456@gmail.com"

    assert result == response_snapshot
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def non_flash_llama_grammar_handle(launcher):
|
||||
with launcher(
|
||||
|
Loading…
Reference in New Issue
Block a user