mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve client typing and update tests
This commit is contained in:
parent
3df37fa941
commit
8b9430fb68
@ -10,6 +10,7 @@ from text_generation.types import (
|
|||||||
Response,
|
Response,
|
||||||
Request,
|
Request,
|
||||||
Parameters,
|
Parameters,
|
||||||
|
Grammar,
|
||||||
)
|
)
|
||||||
from text_generation.errors import parse_error
|
from text_generation.errors import parse_error
|
||||||
|
|
||||||
@ -76,7 +77,7 @@ class Client:
|
|||||||
watermark: bool = False,
|
watermark: bool = False,
|
||||||
decoder_input_details: bool = False,
|
decoder_input_details: bool = False,
|
||||||
top_n_tokens: Optional[int] = None,
|
top_n_tokens: Optional[int] = None,
|
||||||
grammar: str = "",
|
grammar: Optional[Grammar] = None,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Given a prompt, generate the following text
|
Given a prompt, generate the following text
|
||||||
@ -171,7 +172,7 @@ class Client:
|
|||||||
typical_p: Optional[float] = None,
|
typical_p: Optional[float] = None,
|
||||||
watermark: bool = False,
|
watermark: bool = False,
|
||||||
top_n_tokens: Optional[int] = None,
|
top_n_tokens: Optional[int] = None,
|
||||||
grammar: str = "",
|
grammar: Grammar = "",
|
||||||
) -> Iterator[StreamResponse]:
|
) -> Iterator[StreamResponse]:
|
||||||
"""
|
"""
|
||||||
Given a prompt, generate the following stream of tokens
|
Given a prompt, generate the following stream of tokens
|
||||||
@ -330,7 +331,7 @@ class AsyncClient:
|
|||||||
watermark: bool = False,
|
watermark: bool = False,
|
||||||
decoder_input_details: bool = False,
|
decoder_input_details: bool = False,
|
||||||
top_n_tokens: Optional[int] = None,
|
top_n_tokens: Optional[int] = None,
|
||||||
grammar: str = "",
|
grammar: Optional[Grammar] = None,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Given a prompt, generate the following text asynchronously
|
Given a prompt, generate the following text asynchronously
|
||||||
@ -424,7 +425,7 @@ class AsyncClient:
|
|||||||
typical_p: Optional[float] = None,
|
typical_p: Optional[float] = None,
|
||||||
watermark: bool = False,
|
watermark: bool = False,
|
||||||
top_n_tokens: Optional[int] = None,
|
top_n_tokens: Optional[int] = None,
|
||||||
grammar: str = "",
|
grammar: Optional[Grammar] = None,
|
||||||
) -> AsyncIterator[StreamResponse]:
|
) -> AsyncIterator[StreamResponse]:
|
||||||
"""
|
"""
|
||||||
Given a prompt, generate the following stream of tokens asynchronously
|
Given a prompt, generate the following stream of tokens asynchronously
|
||||||
|
@ -1,10 +1,24 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel, validator
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Union
|
||||||
|
|
||||||
from text_generation.errors import ValidationError
|
from text_generation.errors import ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
# enum for grammar type
|
||||||
|
class GrammarType(str, Enum):
|
||||||
|
Json = "json"
|
||||||
|
Regex = "regex"
|
||||||
|
|
||||||
|
|
||||||
|
# Grammar type and value
|
||||||
|
class Grammar(BaseModel):
|
||||||
|
# Grammar type
|
||||||
|
type: GrammarType
|
||||||
|
# Grammar value
|
||||||
|
value: Union[str, dict]
|
||||||
|
|
||||||
|
|
||||||
class Parameters(BaseModel):
|
class Parameters(BaseModel):
|
||||||
# Activate logits sampling
|
# Activate logits sampling
|
||||||
do_sample: bool = False
|
do_sample: bool = False
|
||||||
@ -42,7 +56,7 @@ class Parameters(BaseModel):
|
|||||||
# Return the N most likely tokens at each step
|
# Return the N most likely tokens at each step
|
||||||
top_n_tokens: Optional[int] = None
|
top_n_tokens: Optional[int] = None
|
||||||
# grammar to use for generation
|
# grammar to use for generation
|
||||||
grammar: Optional[str] = None
|
grammar: Optional[Grammar] = None
|
||||||
|
|
||||||
@validator("best_of")
|
@validator("best_of")
|
||||||
def valid_best_of(cls, field_value, values):
|
def valid_best_of(cls, field_value, values):
|
||||||
@ -111,6 +125,14 @@ class Parameters(BaseModel):
|
|||||||
raise ValidationError("`top_n_tokens` must be strictly positive")
|
raise ValidationError("`top_n_tokens` must be strictly positive")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@validator("grammar")
|
||||||
|
def valid_grammar(cls, v):
|
||||||
|
if v is not None:
|
||||||
|
if v.type == GrammarType.Regex and not v.value:
|
||||||
|
raise ValidationError("`value` cannot be empty for `regex` grammar")
|
||||||
|
if v.type == GrammarType.Json and not v.value:
|
||||||
|
raise ValidationError("`value` cannot be empty for `json` grammar")
|
||||||
|
return v
|
||||||
|
|
||||||
class Request(BaseModel):
|
class Request(BaseModel):
|
||||||
# Prompt
|
# Prompt
|
||||||
|
@ -16,7 +16,14 @@ from syrupy.extensions.json import JSONSnapshotExtension
|
|||||||
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
||||||
|
|
||||||
from text_generation import AsyncClient
|
from text_generation import AsyncClient
|
||||||
from text_generation.types import Response, Details, InputToken, Token, BestOfSequence
|
from text_generation.types import (
|
||||||
|
Response,
|
||||||
|
Details,
|
||||||
|
InputToken,
|
||||||
|
Token,
|
||||||
|
BestOfSequence,
|
||||||
|
Grammar,
|
||||||
|
)
|
||||||
|
|
||||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||||
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
||||||
@ -139,6 +146,7 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||||||
response.details, other.details
|
response.details, other.details
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# print(serialized_data)
|
||||||
serialized_data = convert_data(serialized_data)
|
serialized_data = convert_data(serialized_data)
|
||||||
snapshot_data = convert_data(snapshot_data)
|
snapshot_data = convert_data(snapshot_data)
|
||||||
|
|
||||||
@ -381,7 +389,7 @@ def generate_load():
|
|||||||
max_new_tokens: int,
|
max_new_tokens: int,
|
||||||
n: int,
|
n: int,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
grammar: Optional[str] = "",
|
grammar: Optional[Grammar] = None,
|
||||||
stop_sequences: Optional[List[str]] = None,
|
stop_sequences: Optional[List[str]] = None,
|
||||||
) -> List[Response]:
|
) -> List[Response]:
|
||||||
futures = [
|
futures = [
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29901,
|
"id": 29901,
|
||||||
"logprob": -3.2265625,
|
"logprob": -3.2324219,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -41,7 +41,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 763,
|
"id": 763,
|
||||||
"logprob": -10.1484375,
|
"logprob": -10.140625,
|
||||||
"text": "like"
|
"text": "like"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -51,7 +51,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 322,
|
"id": 322,
|
||||||
"logprob": -2.5683594,
|
"logprob": -2.5742188,
|
||||||
"text": "and"
|
"text": "and"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -61,22 +61,22 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1023,
|
"id": 1023,
|
||||||
"logprob": -5.0546875,
|
"logprob": -5.0507812,
|
||||||
"text": "two"
|
"text": "two"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 274,
|
"id": 274,
|
||||||
"logprob": -5.3125,
|
"logprob": -5.3164062,
|
||||||
"text": "c"
|
"text": "c"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1446,
|
"id": 1446,
|
||||||
"logprob": -0.6665039,
|
"logprob": -0.6694336,
|
||||||
"text": "ats"
|
"text": "ats"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29889,
|
"id": 29889,
|
||||||
"logprob": -1.0009766,
|
"logprob": -0.9995117,
|
||||||
"text": "."
|
"text": "."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -89,61 +89,61 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 6377,
|
"id": 6377,
|
||||||
"logprob": -0.15002441,
|
"logprob": -0.14916992,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "{\""
|
"text": "{\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29888,
|
"id": 29888,
|
||||||
"logprob": -0.13549805,
|
"logprob": -0.13598633,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "f"
|
"text": "f"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 12935,
|
"id": 12935,
|
||||||
"logprob": -0.017562866,
|
"logprob": -0.017669678,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "irs"
|
"text": "irs"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29873,
|
"id": 29873,
|
||||||
"logprob": -0.0008444786,
|
"logprob": -0.00085639954,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "t"
|
"text": "t"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1170,
|
"id": 1170,
|
||||||
"logprob": -0.0053634644,
|
"logprob": -0.0054016113,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Name"
|
"text": "Name"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4710,
|
"id": 4710,
|
||||||
"logprob": -0.13537598,
|
"logprob": -0.13549805,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\":\""
|
"text": "\":\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 19504,
|
"id": 19504,
|
||||||
"logprob": -0.8886719,
|
"logprob": -0.8852539,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "David"
|
"text": "David"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3284,
|
"id": 3284,
|
||||||
"logprob": -0.16381836,
|
"logprob": -0.16394043,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\",\""
|
"text": "\",\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4230,
|
"id": 4230,
|
||||||
"logprob": -0.02017212,
|
"logprob": -0.020492554,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "last"
|
"text": "last"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1170,
|
"id": 1170,
|
||||||
"logprob": -0.0013923645,
|
"logprob": -0.0013818741,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Name"
|
"text": "Name"
|
||||||
},
|
},
|
||||||
@ -155,37 +155,37 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29950,
|
"id": 29950,
|
||||||
"logprob": -0.11407471,
|
"logprob": -0.11578369,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "H"
|
"text": "H"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 14339,
|
"id": 14339,
|
||||||
"logprob": -0.0040626526,
|
"logprob": -0.004131317,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "olt"
|
"text": "olt"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29920,
|
"id": 29920,
|
||||||
"logprob": -0.0032863617,
|
"logprob": -0.0033359528,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "z"
|
"text": "z"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3284,
|
"id": 3284,
|
||||||
"logprob": -0.20507812,
|
"logprob": -0.20471191,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\",\""
|
"text": "\",\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29882,
|
"id": 29882,
|
||||||
"logprob": -0.0068740845,
|
"logprob": -0.0069274902,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "h"
|
"text": "h"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 20838,
|
"id": 20838,
|
||||||
"logprob": -0.19714355,
|
"logprob": -0.19580078,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "obb"
|
"text": "obb"
|
||||||
},
|
},
|
||||||
@ -197,37 +197,37 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4710,
|
"id": 4710,
|
||||||
"logprob": -0.31860352,
|
"logprob": -0.32080078,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\":\""
|
"text": "\":\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29911,
|
"id": 29911,
|
||||||
"logprob": -2.09375,
|
"logprob": -2.1035156,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "T"
|
"text": "T"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 11003,
|
"id": 11003,
|
||||||
"logprob": -0.02053833,
|
"logprob": -0.020767212,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "rees"
|
"text": "rees"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3284,
|
"id": 3284,
|
||||||
"logprob": -0.59814453,
|
"logprob": -0.6010742,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\",\""
|
"text": "\",\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29876,
|
"id": 29876,
|
||||||
"logprob": -0.5732422,
|
"logprob": -0.57666016,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "n"
|
"text": "n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 398,
|
"id": 398,
|
||||||
"logprob": -0.006198883,
|
"logprob": -0.0061073303,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "um"
|
"text": "um"
|
||||||
},
|
},
|
||||||
@ -245,19 +245,19 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1115,
|
"id": 1115,
|
||||||
"logprob": -0.002117157,
|
"logprob": -0.0021018982,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\":"
|
"text": "\":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29906,
|
"id": 29906,
|
||||||
"logprob": -0.089416504,
|
"logprob": -0.08996582,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "2"
|
"text": "2"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29913,
|
"id": 29913,
|
||||||
"logprob": -0.021835327,
|
"logprob": -0.021697998,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "}"
|
"text": "}"
|
||||||
},
|
},
|
||||||
|
@ -11,12 +11,12 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 806,
|
"id": 806,
|
||||||
"logprob": -11.90625,
|
"logprob": -11.890625,
|
||||||
"text": "Wh"
|
"text": "Wh"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1446,
|
"id": 1446,
|
||||||
"logprob": -3.6660156,
|
"logprob": -3.6699219,
|
||||||
"text": "ats"
|
"text": "ats"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -26,17 +26,17 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 468,
|
"id": 468,
|
||||||
"logprob": -8.0625,
|
"logprob": -8.0703125,
|
||||||
"text": "og"
|
"text": "og"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 793,
|
"id": 793,
|
||||||
"logprob": -2.1816406,
|
"logprob": -2.1875,
|
||||||
"text": "les"
|
"text": "les"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 16332,
|
"id": 16332,
|
||||||
"logprob": -9.71875,
|
"logprob": -9.7109375,
|
||||||
"text": "DNS"
|
"text": "DNS"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -44,13 +44,13 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 29946,
|
"id": 29946,
|
||||||
"logprob": -1.4736328,
|
"logprob": -1.4765625,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "4"
|
"text": "4"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29906,
|
"id": 29906,
|
||||||
"logprob": -0.91845703,
|
"logprob": -0.9199219,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "2"
|
"text": "2"
|
||||||
},
|
},
|
||||||
@ -62,43 +62,43 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29896,
|
"id": 29896,
|
||||||
"logprob": -1.1386719,
|
"logprob": -1.1367188,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "1"
|
"text": "1"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29889,
|
"id": 29889,
|
||||||
"logprob": -1.4638672,
|
"logprob": -1.4648438,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "."
|
"text": "."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29896,
|
"id": 29896,
|
||||||
"logprob": -0.40771484,
|
"logprob": -0.40722656,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "1"
|
"text": "1"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29889,
|
"id": 29889,
|
||||||
"logprob": -0.17553711,
|
"logprob": -0.17419434,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "."
|
"text": "."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29896,
|
"id": 29896,
|
||||||
"logprob": -0.20776367,
|
"logprob": -0.20251465,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "1"
|
"text": "1"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29900,
|
"id": 29900,
|
||||||
"logprob": -1.5546875,
|
"logprob": -1.5527344,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "0"
|
"text": "0"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29896,
|
"id": 29896,
|
||||||
"logprob": -1.3681641,
|
"logprob": -1.3710938,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "1"
|
"text": "1"
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from text_generation.types import GrammarType
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def flash_llama_grammar_handle(launcher):
|
def flash_llama_grammar_handle(launcher):
|
||||||
with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, grammar_support=True) as handle:
|
with launcher(
|
||||||
|
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, grammar_support=True
|
||||||
|
) as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +37,10 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot)
|
|||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
decoder_input_details=True,
|
decoder_input_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
grammar="((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
|
grammar={
|
||||||
|
"type": GrammarType.Regex, # "regex"
|
||||||
|
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
@ -49,12 +56,14 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
|||||||
max_new_tokens=100,
|
max_new_tokens=100,
|
||||||
decoder_input_details=True,
|
decoder_input_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
grammar=json.dumps(
|
grammar={
|
||||||
|
"type": GrammarType.Json, # "json"
|
||||||
|
"value": json.dumps(
|
||||||
{
|
{
|
||||||
|
"type": "object",
|
||||||
"$id": "https://example.com/person.schema.json",
|
"$id": "https://example.com/person.schema.json",
|
||||||
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||||
"title": "Person",
|
"title": "Person",
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
"properties": {
|
||||||
"firstName": {
|
"firstName": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@ -64,7 +73,10 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The person'''s last name.",
|
"description": "The person'''s last name.",
|
||||||
},
|
},
|
||||||
"hobby": {"description": "The person'''s hobby.", "type": "string"},
|
"hobby": {
|
||||||
|
"description": "The person'''s hobby.",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
"numCats": {
|
"numCats": {
|
||||||
"description": "The number of cats the person has.",
|
"description": "The number of cats the person has.",
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
@ -74,6 +86,7 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
|||||||
"required": ["firstName", "lastName", "hobby", "numCats"],
|
"required": ["firstName", "lastName", "hobby", "numCats"],
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 30
|
assert response.details.generated_tokens == 30
|
||||||
@ -96,7 +109,10 @@ async def test_flash_llama_grammar_load(
|
|||||||
n=4,
|
n=4,
|
||||||
stop_sequences=[".com"],
|
stop_sequences=[".com"],
|
||||||
seed=0,
|
seed=0,
|
||||||
grammar="[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
grammar={
|
||||||
|
"type": GrammarType.Regex, # "regex"
|
||||||
|
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(responses) == 4
|
assert len(responses) == 4
|
||||||
@ -123,7 +139,10 @@ async def test_flash_llama_grammar_single_load_instance(
|
|||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
stop_sequences=[".com"],
|
stop_sequences=[".com"],
|
||||||
seed=0,
|
seed=0,
|
||||||
grammar="[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
grammar={
|
||||||
|
"type": GrammarType.Regex, # "regex"
|
||||||
|
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# assert response.details.generated_tokens == 30
|
# assert response.details.generated_tokens == 30
|
||||||
|
Loading…
Reference in New Issue
Block a user