feat: improve client typing and update tests

This commit is contained in:
drbh 2024-02-13 23:59:56 +00:00
parent 3df37fa941
commit 8b9430fb68
6 changed files with 132 additions and 82 deletions

View File

@ -10,6 +10,7 @@ from text_generation.types import (
Response,
Request,
Parameters,
Grammar,
)
from text_generation.errors import parse_error
@ -76,7 +77,7 @@ class Client:
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
grammar: str = "",
grammar: Optional[Grammar] = None,
) -> Response:
"""
Given a prompt, generate the following text
@ -171,7 +172,7 @@ class Client:
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
grammar: str = "",
grammar: Optional[Grammar] = None,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
@ -330,7 +331,7 @@ class AsyncClient:
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
grammar: str = "",
grammar: Optional[Grammar] = None,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
@ -424,7 +425,7 @@ class AsyncClient:
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
grammar: str = "",
grammar: Optional[Grammar] = None,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously

View File

@ -1,10 +1,24 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List
from typing import Optional, List, Union
from text_generation.errors import ValidationError
# enum for grammar type
class GrammarType(str, Enum):
Json = "json"
Regex = "regex"
# Grammar type and value
class Grammar(BaseModel):
# Grammar type
type: GrammarType
# Grammar value
value: Union[str, dict]
class Parameters(BaseModel):
# Activate logits sampling
do_sample: bool = False
@ -42,7 +56,7 @@ class Parameters(BaseModel):
# Return the N most likely tokens at each step
top_n_tokens: Optional[int] = None
# grammar to use for generation
grammar: Optional[str] = None
grammar: Optional[Grammar] = None
@validator("best_of")
def valid_best_of(cls, field_value, values):
@ -111,6 +125,14 @@ class Parameters(BaseModel):
raise ValidationError("`top_n_tokens` must be strictly positive")
return v
@validator("grammar")
def valid_grammar(cls, v):
if v is not None:
if v.type == GrammarType.Regex and not v.value:
raise ValidationError("`value` cannot be empty for `regex` grammar")
if v.type == GrammarType.Json and not v.value:
raise ValidationError("`value` cannot be empty for `json` grammar")
return v
class Request(BaseModel):
# Prompt

View File

@ -16,7 +16,14 @@ from syrupy.extensions.json import JSONSnapshotExtension
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from text_generation import AsyncClient
from text_generation.types import Response, Details, InputToken, Token, BestOfSequence
from text_generation.types import (
Response,
Details,
InputToken,
Token,
BestOfSequence,
Grammar,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
@ -139,6 +146,7 @@ class ResponseComparator(JSONSnapshotExtension):
response.details, other.details
)
# print(serialized_data)
serialized_data = convert_data(serialized_data)
snapshot_data = convert_data(snapshot_data)
@ -381,7 +389,7 @@ def generate_load():
max_new_tokens: int,
n: int,
seed: Optional[int] = None,
grammar: Optional[str] = "",
grammar: Optional[Grammar] = None,
stop_sequences: Optional[List[str]] = None,
) -> List[Response]:
futures = [

View File

@ -16,7 +16,7 @@
},
{
"id": 29901,
"logprob": -3.2265625,
"logprob": -3.2324219,
"text": ":"
},
{
@ -41,7 +41,7 @@
},
{
"id": 763,
"logprob": -10.1484375,
"logprob": -10.140625,
"text": "like"
},
{
@ -51,7 +51,7 @@
},
{
"id": 322,
"logprob": -2.5683594,
"logprob": -2.5742188,
"text": "and"
},
{
@ -61,22 +61,22 @@
},
{
"id": 1023,
"logprob": -5.0546875,
"logprob": -5.0507812,
"text": "two"
},
{
"id": 274,
"logprob": -5.3125,
"logprob": -5.3164062,
"text": "c"
},
{
"id": 1446,
"logprob": -0.6665039,
"logprob": -0.6694336,
"text": "ats"
},
{
"id": 29889,
"logprob": -1.0009766,
"logprob": -0.9995117,
"text": "."
},
{
@ -89,61 +89,61 @@
"tokens": [
{
"id": 6377,
"logprob": -0.15002441,
"logprob": -0.14916992,
"special": false,
"text": "{\""
},
{
"id": 29888,
"logprob": -0.13549805,
"logprob": -0.13598633,
"special": false,
"text": "f"
},
{
"id": 12935,
"logprob": -0.017562866,
"logprob": -0.017669678,
"special": false,
"text": "irs"
},
{
"id": 29873,
"logprob": -0.0008444786,
"logprob": -0.00085639954,
"special": false,
"text": "t"
},
{
"id": 1170,
"logprob": -0.0053634644,
"logprob": -0.0054016113,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.13537598,
"logprob": -0.13549805,
"special": false,
"text": "\":\""
},
{
"id": 19504,
"logprob": -0.8886719,
"logprob": -0.8852539,
"special": false,
"text": "David"
},
{
"id": 3284,
"logprob": -0.16381836,
"logprob": -0.16394043,
"special": false,
"text": "\",\""
},
{
"id": 4230,
"logprob": -0.02017212,
"logprob": -0.020492554,
"special": false,
"text": "last"
},
{
"id": 1170,
"logprob": -0.0013923645,
"logprob": -0.0013818741,
"special": false,
"text": "Name"
},
@ -155,37 +155,37 @@
},
{
"id": 29950,
"logprob": -0.11407471,
"logprob": -0.11578369,
"special": false,
"text": "H"
},
{
"id": 14339,
"logprob": -0.0040626526,
"logprob": -0.004131317,
"special": false,
"text": "olt"
},
{
"id": 29920,
"logprob": -0.0032863617,
"logprob": -0.0033359528,
"special": false,
"text": "z"
},
{
"id": 3284,
"logprob": -0.20507812,
"logprob": -0.20471191,
"special": false,
"text": "\",\""
},
{
"id": 29882,
"logprob": -0.0068740845,
"logprob": -0.0069274902,
"special": false,
"text": "h"
},
{
"id": 20838,
"logprob": -0.19714355,
"logprob": -0.19580078,
"special": false,
"text": "obb"
},
@ -197,37 +197,37 @@
},
{
"id": 4710,
"logprob": -0.31860352,
"logprob": -0.32080078,
"special": false,
"text": "\":\""
},
{
"id": 29911,
"logprob": -2.09375,
"logprob": -2.1035156,
"special": false,
"text": "T"
},
{
"id": 11003,
"logprob": -0.02053833,
"logprob": -0.020767212,
"special": false,
"text": "rees"
},
{
"id": 3284,
"logprob": -0.59814453,
"logprob": -0.6010742,
"special": false,
"text": "\",\""
},
{
"id": 29876,
"logprob": -0.5732422,
"logprob": -0.57666016,
"special": false,
"text": "n"
},
{
"id": 398,
"logprob": -0.006198883,
"logprob": -0.0061073303,
"special": false,
"text": "um"
},
@ -245,19 +245,19 @@
},
{
"id": 1115,
"logprob": -0.002117157,
"logprob": -0.0021018982,
"special": false,
"text": "\":"
},
{
"id": 29906,
"logprob": -0.089416504,
"logprob": -0.08996582,
"special": false,
"text": "2"
},
{
"id": 29913,
"logprob": -0.021835327,
"logprob": -0.021697998,
"special": false,
"text": "}"
},

View File

@ -11,12 +11,12 @@
},
{
"id": 806,
"logprob": -11.90625,
"logprob": -11.890625,
"text": "Wh"
},
{
"id": 1446,
"logprob": -3.6660156,
"logprob": -3.6699219,
"text": "ats"
},
{
@ -26,17 +26,17 @@
},
{
"id": 468,
"logprob": -8.0625,
"logprob": -8.0703125,
"text": "og"
},
{
"id": 793,
"logprob": -2.1816406,
"logprob": -2.1875,
"text": "les"
},
{
"id": 16332,
"logprob": -9.71875,
"logprob": -9.7109375,
"text": "DNS"
}
],
@ -44,13 +44,13 @@
"tokens": [
{
"id": 29946,
"logprob": -1.4736328,
"logprob": -1.4765625,
"special": false,
"text": "4"
},
{
"id": 29906,
"logprob": -0.91845703,
"logprob": -0.9199219,
"special": false,
"text": "2"
},
@ -62,43 +62,43 @@
},
{
"id": 29896,
"logprob": -1.1386719,
"logprob": -1.1367188,
"special": false,
"text": "1"
},
{
"id": 29889,
"logprob": -1.4638672,
"logprob": -1.4648438,
"special": false,
"text": "."
},
{
"id": 29896,
"logprob": -0.40771484,
"logprob": -0.40722656,
"special": false,
"text": "1"
},
{
"id": 29889,
"logprob": -0.17553711,
"logprob": -0.17419434,
"special": false,
"text": "."
},
{
"id": 29896,
"logprob": -0.20776367,
"logprob": -0.20251465,
"special": false,
"text": "1"
},
{
"id": 29900,
"logprob": -1.5546875,
"logprob": -1.5527344,
"special": false,
"text": "0"
},
{
"id": 29896,
"logprob": -1.3681641,
"logprob": -1.3710938,
"special": false,
"text": "1"
}

View File

@ -1,10 +1,14 @@
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
def flash_llama_grammar_handle(launcher):
with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, grammar_support=True) as handle:
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, grammar_support=True
) as handle:
yield handle
@ -33,7 +37,10 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot)
max_new_tokens=10,
decoder_input_details=True,
seed=0,
grammar="((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
grammar={
"type": GrammarType.Regex, # "regex"
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
},
)
assert response.details.generated_tokens == 10
@ -49,31 +56,37 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
max_new_tokens=100,
decoder_input_details=True,
seed=0,
grammar=json.dumps(
{
"$id": "https://example.com/person.schema.json",
"$schema": "https://json-schema.org/draft/2020-12/schema",
"title": "Person",
"type": "object",
"properties": {
"firstName": {
"type": "string",
"description": "The person'''s first name.",
grammar={
"type": GrammarType.Json, # "json"
"value": json.dumps(
{
"type": "object",
"$id": "https://example.com/person.schema.json",
"$schema": "https://json-schema.org/draft/2020-12/schema",
"title": "Person",
"properties": {
"firstName": {
"type": "string",
"description": "The person'''s first name.",
},
"lastName": {
"type": "string",
"description": "The person'''s last name.",
},
"hobby": {
"description": "The person'''s hobby.",
"type": "string",
},
"numCats": {
"description": "The number of cats the person has.",
"type": "integer",
"minimum": 0,
},
},
"lastName": {
"type": "string",
"description": "The person'''s last name.",
},
"hobby": {"description": "The person'''s hobby.", "type": "string"},
"numCats": {
"description": "The number of cats the person has.",
"type": "integer",
"minimum": 0,
},
},
"required": ["firstName", "lastName", "hobby", "numCats"],
}
),
"required": ["firstName", "lastName", "hobby", "numCats"],
}
),
},
)
assert response.details.generated_tokens == 30
@ -96,7 +109,10 @@ async def test_flash_llama_grammar_load(
n=4,
stop_sequences=[".com"],
seed=0,
grammar="[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
assert len(responses) == 4
@ -123,7 +139,10 @@ async def test_flash_llama_grammar_single_load_instance(
max_new_tokens=10,
stop_sequences=[".com"],
seed=0,
grammar="[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
# assert response.details.generated_tokens == 30