feat: improve client typing and update tests

This commit is contained in:
drbh 2024-02-13 23:59:56 +00:00
parent 3df37fa941
commit 8b9430fb68
6 changed files with 132 additions and 82 deletions

View File

@ -10,6 +10,7 @@ from text_generation.types import (
Response, Response,
Request, Request,
Parameters, Parameters,
Grammar,
) )
from text_generation.errors import parse_error from text_generation.errors import parse_error
@ -76,7 +77,7 @@ class Client:
watermark: bool = False, watermark: bool = False,
decoder_input_details: bool = False, decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: str = "", grammar: Optional[Grammar] = None,
) -> Response: ) -> Response:
""" """
Given a prompt, generate the following text Given a prompt, generate the following text
@ -171,7 +172,7 @@ class Client:
typical_p: Optional[float] = None, typical_p: Optional[float] = None,
watermark: bool = False, watermark: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: str = "", grammar: Grammar = "",
) -> Iterator[StreamResponse]: ) -> Iterator[StreamResponse]:
""" """
Given a prompt, generate the following stream of tokens Given a prompt, generate the following stream of tokens
@ -330,7 +331,7 @@ class AsyncClient:
watermark: bool = False, watermark: bool = False,
decoder_input_details: bool = False, decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: str = "", grammar: Optional[Grammar] = None,
) -> Response: ) -> Response:
""" """
Given a prompt, generate the following text asynchronously Given a prompt, generate the following text asynchronously
@ -424,7 +425,7 @@ class AsyncClient:
typical_p: Optional[float] = None, typical_p: Optional[float] = None,
watermark: bool = False, watermark: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: str = "", grammar: Optional[Grammar] = None,
) -> AsyncIterator[StreamResponse]: ) -> AsyncIterator[StreamResponse]:
""" """
Given a prompt, generate the following stream of tokens asynchronously Given a prompt, generate the following stream of tokens asynchronously

View File

@ -1,10 +1,24 @@
from enum import Enum from enum import Enum
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
from typing import Optional, List from typing import Optional, List, Union
from text_generation.errors import ValidationError from text_generation.errors import ValidationError
# Enumeration of the supported grammar types.
# Mixing in `str` makes every member compare equal to its plain string value.
class GrammarType(str, Enum):
    # JSON grammar (string value "json")
    Json = "json"
    # Regular-expression grammar (string value "regex")
    Regex = "regex"
# Grammar specification: a grammar type paired with its value
class Grammar(BaseModel):
    # Grammar type (one of the `GrammarType` members)
    type: GrammarType
    # Grammar value, given either as a string or as a dict
    value: Union[str, dict]
class Parameters(BaseModel): class Parameters(BaseModel):
# Activate logits sampling # Activate logits sampling
do_sample: bool = False do_sample: bool = False
@ -42,7 +56,7 @@ class Parameters(BaseModel):
# Return the N most likely tokens at each step # Return the N most likely tokens at each step
top_n_tokens: Optional[int] = None top_n_tokens: Optional[int] = None
# grammar to use for generation # grammar to use for generation
grammar: Optional[str] = None grammar: Optional[Grammar] = None
@validator("best_of") @validator("best_of")
def valid_best_of(cls, field_value, values): def valid_best_of(cls, field_value, values):
@ -111,6 +125,14 @@ class Parameters(BaseModel):
raise ValidationError("`top_n_tokens` must be strictly positive") raise ValidationError("`top_n_tokens` must be strictly positive")
return v return v
@validator("grammar")
def valid_grammar(cls, v):
    """Reject a grammar whose `value` is empty.

    `None` (no grammar supplied) is returned untouched. When a grammar is
    supplied, its `value` must be truthy; an empty string or empty dict
    raises `ValidationError` with a message naming the grammar type.
    """
    # Nothing to check when no grammar was given or its value is non-empty.
    if v is None or v.value:
        return v
    # Normalise through the enum so the message always carries the plain
    # string value ("regex" / "json"), never an enum repr.
    grammar_kind = GrammarType(v.type).value
    raise ValidationError(f"`value` cannot be empty for `{grammar_kind}` grammar")
class Request(BaseModel): class Request(BaseModel):
# Prompt # Prompt

View File

@ -16,7 +16,14 @@ from syrupy.extensions.json import JSONSnapshotExtension
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from text_generation import AsyncClient from text_generation import AsyncClient
from text_generation.types import Response, Details, InputToken, Token, BestOfSequence from text_generation.types import (
Response,
Details,
InputToken,
Token,
BestOfSequence,
Grammar,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
@ -139,6 +146,7 @@ class ResponseComparator(JSONSnapshotExtension):
response.details, other.details response.details, other.details
) )
# print(serialized_data)
serialized_data = convert_data(serialized_data) serialized_data = convert_data(serialized_data)
snapshot_data = convert_data(snapshot_data) snapshot_data = convert_data(snapshot_data)
@ -381,7 +389,7 @@ def generate_load():
max_new_tokens: int, max_new_tokens: int,
n: int, n: int,
seed: Optional[int] = None, seed: Optional[int] = None,
grammar: Optional[str] = "", grammar: Optional[Grammar] = None,
stop_sequences: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None,
) -> List[Response]: ) -> List[Response]:
futures = [ futures = [

View File

@ -16,7 +16,7 @@
}, },
{ {
"id": 29901, "id": 29901,
"logprob": -3.2265625, "logprob": -3.2324219,
"text": ":" "text": ":"
}, },
{ {
@ -41,7 +41,7 @@
}, },
{ {
"id": 763, "id": 763,
"logprob": -10.1484375, "logprob": -10.140625,
"text": "like" "text": "like"
}, },
{ {
@ -51,7 +51,7 @@
}, },
{ {
"id": 322, "id": 322,
"logprob": -2.5683594, "logprob": -2.5742188,
"text": "and" "text": "and"
}, },
{ {
@ -61,22 +61,22 @@
}, },
{ {
"id": 1023, "id": 1023,
"logprob": -5.0546875, "logprob": -5.0507812,
"text": "two" "text": "two"
}, },
{ {
"id": 274, "id": 274,
"logprob": -5.3125, "logprob": -5.3164062,
"text": "c" "text": "c"
}, },
{ {
"id": 1446, "id": 1446,
"logprob": -0.6665039, "logprob": -0.6694336,
"text": "ats" "text": "ats"
}, },
{ {
"id": 29889, "id": 29889,
"logprob": -1.0009766, "logprob": -0.9995117,
"text": "." "text": "."
}, },
{ {
@ -89,61 +89,61 @@
"tokens": [ "tokens": [
{ {
"id": 6377, "id": 6377,
"logprob": -0.15002441, "logprob": -0.14916992,
"special": false, "special": false,
"text": "{\"" "text": "{\""
}, },
{ {
"id": 29888, "id": 29888,
"logprob": -0.13549805, "logprob": -0.13598633,
"special": false, "special": false,
"text": "f" "text": "f"
}, },
{ {
"id": 12935, "id": 12935,
"logprob": -0.017562866, "logprob": -0.017669678,
"special": false, "special": false,
"text": "irs" "text": "irs"
}, },
{ {
"id": 29873, "id": 29873,
"logprob": -0.0008444786, "logprob": -0.00085639954,
"special": false, "special": false,
"text": "t" "text": "t"
}, },
{ {
"id": 1170, "id": 1170,
"logprob": -0.0053634644, "logprob": -0.0054016113,
"special": false, "special": false,
"text": "Name" "text": "Name"
}, },
{ {
"id": 4710, "id": 4710,
"logprob": -0.13537598, "logprob": -0.13549805,
"special": false, "special": false,
"text": "\":\"" "text": "\":\""
}, },
{ {
"id": 19504, "id": 19504,
"logprob": -0.8886719, "logprob": -0.8852539,
"special": false, "special": false,
"text": "David" "text": "David"
}, },
{ {
"id": 3284, "id": 3284,
"logprob": -0.16381836, "logprob": -0.16394043,
"special": false, "special": false,
"text": "\",\"" "text": "\",\""
}, },
{ {
"id": 4230, "id": 4230,
"logprob": -0.02017212, "logprob": -0.020492554,
"special": false, "special": false,
"text": "last" "text": "last"
}, },
{ {
"id": 1170, "id": 1170,
"logprob": -0.0013923645, "logprob": -0.0013818741,
"special": false, "special": false,
"text": "Name" "text": "Name"
}, },
@ -155,37 +155,37 @@
}, },
{ {
"id": 29950, "id": 29950,
"logprob": -0.11407471, "logprob": -0.11578369,
"special": false, "special": false,
"text": "H" "text": "H"
}, },
{ {
"id": 14339, "id": 14339,
"logprob": -0.0040626526, "logprob": -0.004131317,
"special": false, "special": false,
"text": "olt" "text": "olt"
}, },
{ {
"id": 29920, "id": 29920,
"logprob": -0.0032863617, "logprob": -0.0033359528,
"special": false, "special": false,
"text": "z" "text": "z"
}, },
{ {
"id": 3284, "id": 3284,
"logprob": -0.20507812, "logprob": -0.20471191,
"special": false, "special": false,
"text": "\",\"" "text": "\",\""
}, },
{ {
"id": 29882, "id": 29882,
"logprob": -0.0068740845, "logprob": -0.0069274902,
"special": false, "special": false,
"text": "h" "text": "h"
}, },
{ {
"id": 20838, "id": 20838,
"logprob": -0.19714355, "logprob": -0.19580078,
"special": false, "special": false,
"text": "obb" "text": "obb"
}, },
@ -197,37 +197,37 @@
}, },
{ {
"id": 4710, "id": 4710,
"logprob": -0.31860352, "logprob": -0.32080078,
"special": false, "special": false,
"text": "\":\"" "text": "\":\""
}, },
{ {
"id": 29911, "id": 29911,
"logprob": -2.09375, "logprob": -2.1035156,
"special": false, "special": false,
"text": "T" "text": "T"
}, },
{ {
"id": 11003, "id": 11003,
"logprob": -0.02053833, "logprob": -0.020767212,
"special": false, "special": false,
"text": "rees" "text": "rees"
}, },
{ {
"id": 3284, "id": 3284,
"logprob": -0.59814453, "logprob": -0.6010742,
"special": false, "special": false,
"text": "\",\"" "text": "\",\""
}, },
{ {
"id": 29876, "id": 29876,
"logprob": -0.5732422, "logprob": -0.57666016,
"special": false, "special": false,
"text": "n" "text": "n"
}, },
{ {
"id": 398, "id": 398,
"logprob": -0.006198883, "logprob": -0.0061073303,
"special": false, "special": false,
"text": "um" "text": "um"
}, },
@ -245,19 +245,19 @@
}, },
{ {
"id": 1115, "id": 1115,
"logprob": -0.002117157, "logprob": -0.0021018982,
"special": false, "special": false,
"text": "\":" "text": "\":"
}, },
{ {
"id": 29906, "id": 29906,
"logprob": -0.089416504, "logprob": -0.08996582,
"special": false, "special": false,
"text": "2" "text": "2"
}, },
{ {
"id": 29913, "id": 29913,
"logprob": -0.021835327, "logprob": -0.021697998,
"special": false, "special": false,
"text": "}" "text": "}"
}, },

View File

@ -11,12 +11,12 @@
}, },
{ {
"id": 806, "id": 806,
"logprob": -11.90625, "logprob": -11.890625,
"text": "Wh" "text": "Wh"
}, },
{ {
"id": 1446, "id": 1446,
"logprob": -3.6660156, "logprob": -3.6699219,
"text": "ats" "text": "ats"
}, },
{ {
@ -26,17 +26,17 @@
}, },
{ {
"id": 468, "id": 468,
"logprob": -8.0625, "logprob": -8.0703125,
"text": "og" "text": "og"
}, },
{ {
"id": 793, "id": 793,
"logprob": -2.1816406, "logprob": -2.1875,
"text": "les" "text": "les"
}, },
{ {
"id": 16332, "id": 16332,
"logprob": -9.71875, "logprob": -9.7109375,
"text": "DNS" "text": "DNS"
} }
], ],
@ -44,13 +44,13 @@
"tokens": [ "tokens": [
{ {
"id": 29946, "id": 29946,
"logprob": -1.4736328, "logprob": -1.4765625,
"special": false, "special": false,
"text": "4" "text": "4"
}, },
{ {
"id": 29906, "id": 29906,
"logprob": -0.91845703, "logprob": -0.9199219,
"special": false, "special": false,
"text": "2" "text": "2"
}, },
@ -62,43 +62,43 @@
}, },
{ {
"id": 29896, "id": 29896,
"logprob": -1.1386719, "logprob": -1.1367188,
"special": false, "special": false,
"text": "1" "text": "1"
}, },
{ {
"id": 29889, "id": 29889,
"logprob": -1.4638672, "logprob": -1.4648438,
"special": false, "special": false,
"text": "." "text": "."
}, },
{ {
"id": 29896, "id": 29896,
"logprob": -0.40771484, "logprob": -0.40722656,
"special": false, "special": false,
"text": "1" "text": "1"
}, },
{ {
"id": 29889, "id": 29889,
"logprob": -0.17553711, "logprob": -0.17419434,
"special": false, "special": false,
"text": "." "text": "."
}, },
{ {
"id": 29896, "id": 29896,
"logprob": -0.20776367, "logprob": -0.20251465,
"special": false, "special": false,
"text": "1" "text": "1"
}, },
{ {
"id": 29900, "id": 29900,
"logprob": -1.5546875, "logprob": -1.5527344,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 29896, "id": 29896,
"logprob": -1.3681641, "logprob": -1.3710938,
"special": false, "special": false,
"text": "1" "text": "1"
} }

View File

@ -1,10 +1,14 @@
import pytest import pytest
import json import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_llama_grammar_handle(launcher): def flash_llama_grammar_handle(launcher):
with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, grammar_support=True) as handle: with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, grammar_support=True
) as handle:
yield handle yield handle
@ -33,7 +37,10 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot)
max_new_tokens=10, max_new_tokens=10,
decoder_input_details=True, decoder_input_details=True,
seed=0, seed=0,
grammar="((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", grammar={
"type": GrammarType.Regex, # "regex"
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
},
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
@ -49,31 +56,37 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
max_new_tokens=100, max_new_tokens=100,
decoder_input_details=True, decoder_input_details=True,
seed=0, seed=0,
grammar=json.dumps( grammar={
{ "type": GrammarType.Json, # "json"
"$id": "https://example.com/person.schema.json", "value": json.dumps(
"$schema": "https://json-schema.org/draft/2020-12/schema", {
"title": "Person", "type": "object",
"type": "object", "$id": "https://example.com/person.schema.json",
"properties": { "$schema": "https://json-schema.org/draft/2020-12/schema",
"firstName": { "title": "Person",
"type": "string", "properties": {
"description": "The person'''s first name.", "firstName": {
"type": "string",
"description": "The person'''s first name.",
},
"lastName": {
"type": "string",
"description": "The person'''s last name.",
},
"hobby": {
"description": "The person'''s hobby.",
"type": "string",
},
"numCats": {
"description": "The number of cats the person has.",
"type": "integer",
"minimum": 0,
},
}, },
"lastName": { "required": ["firstName", "lastName", "hobby", "numCats"],
"type": "string", }
"description": "The person'''s last name.", ),
}, },
"hobby": {"description": "The person'''s hobby.", "type": "string"},
"numCats": {
"description": "The number of cats the person has.",
"type": "integer",
"minimum": 0,
},
},
"required": ["firstName", "lastName", "hobby", "numCats"],
}
),
) )
assert response.details.generated_tokens == 30 assert response.details.generated_tokens == 30
@ -96,7 +109,10 @@ async def test_flash_llama_grammar_load(
n=4, n=4,
stop_sequences=[".com"], stop_sequences=[".com"],
seed=0, seed=0,
grammar="[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
) )
assert len(responses) == 4 assert len(responses) == 4
@ -123,7 +139,10 @@ async def test_flash_llama_grammar_single_load_instance(
max_new_tokens=10, max_new_tokens=10,
stop_sequences=[".com"], stop_sequences=[".com"],
seed=0, seed=0,
grammar="[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
) )
# assert response.details.generated_tokens == 30 # assert response.details.generated_tokens == 30