From 8b9430fb68e7cba2bde3555a904adac5ab2ce5be Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 13 Feb 2024 23:59:56 +0000 Subject: [PATCH] feat: improve client typing and update tests --- clients/python/text_generation/client.py | 9 ++- clients/python/text_generation/types.py | 26 ++++++- integration-tests/conftest.py | 12 ++- .../test_flash_llama_grammar_json.json | 64 ++++++++-------- .../test_flash_llama_grammar_regex.json | 28 +++---- .../models/test_grammar_llama.py | 75 ++++++++++++------- 6 files changed, 132 insertions(+), 82 deletions(-) diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index d11faa2a..5afe5dfe 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -10,6 +10,7 @@ from text_generation.types import ( Response, Request, Parameters, + Grammar, ) from text_generation.errors import parse_error @@ -76,7 +77,7 @@ class Client: watermark: bool = False, decoder_input_details: bool = False, top_n_tokens: Optional[int] = None, - grammar: str = "", + grammar: Optional[Grammar] = None, ) -> Response: """ Given a prompt, generate the following text @@ -171,7 +172,7 @@ typical_p: Optional[float] = None, watermark: bool = False, top_n_tokens: Optional[int] = None, - grammar: str = "", + grammar: Optional[Grammar] = None, ) -> Iterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens @@ -330,7 +331,7 @@ class AsyncClient: watermark: bool = False, decoder_input_details: bool = False, top_n_tokens: Optional[int] = None, - grammar: str = "", + grammar: Optional[Grammar] = None, ) -> Response: """ Given a prompt, generate the following text asynchronously @@ -424,7 +425,7 @@ typical_p: Optional[float] = None, watermark: bool = False, top_n_tokens: Optional[int] = None, - grammar: str = "", + grammar: Optional[Grammar] = None, ) -> AsyncIterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens
asynchronously diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 3369a5fd..3426411b 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -1,10 +1,24 @@ from enum import Enum from pydantic import BaseModel, validator -from typing import Optional, List +from typing import Optional, List, Union from text_generation.errors import ValidationError +# enum for grammar type +class GrammarType(str, Enum): + Json = "json" + Regex = "regex" + + +# Grammar type and value +class Grammar(BaseModel): + # Grammar type + type: GrammarType + # Grammar value + value: Union[str, dict] + + class Parameters(BaseModel): # Activate logits sampling do_sample: bool = False @@ -42,7 +56,7 @@ class Parameters(BaseModel): # Return the N most likely tokens at each step top_n_tokens: Optional[int] = None # grammar to use for generation - grammar: Optional[str] = None + grammar: Optional[Grammar] = None @validator("best_of") def valid_best_of(cls, field_value, values): @@ -111,6 +125,14 @@ class Parameters(BaseModel): raise ValidationError("`top_n_tokens` must be strictly positive") return v + @validator("grammar") + def valid_grammar(cls, v): + if v is not None: + if v.type == GrammarType.Regex and not v.value: + raise ValidationError("`value` cannot be empty for `regex` grammar") + if v.type == GrammarType.Json and not v.value: + raise ValidationError("`value` cannot be empty for `json` grammar") + return v class Request(BaseModel): # Prompt diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index c97b039b..d499fee9 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -16,7 +16,14 @@ from syrupy.extensions.json import JSONSnapshotExtension from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError from text_generation import AsyncClient -from text_generation.types import Response, Details, InputToken, Token, BestOfSequence 
+from text_generation.types import ( + Response, + Details, + InputToken, + Token, + BestOfSequence, + Grammar, +) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) @@ -139,6 +146,7 @@ class ResponseComparator(JSONSnapshotExtension): response.details, other.details ) + # print(serialized_data) serialized_data = convert_data(serialized_data) snapshot_data = convert_data(snapshot_data) @@ -381,7 +389,7 @@ def generate_load(): max_new_tokens: int, n: int, seed: Optional[int] = None, - grammar: Optional[str] = "", + grammar: Optional[Grammar] = None, stop_sequences: Optional[List[str]] = None, ) -> List[Response]: futures = [ diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json index d0e017c1..7b12b158 100644 --- a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json @@ -16,7 +16,7 @@ }, { "id": 29901, - "logprob": -3.2265625, + "logprob": -3.2324219, "text": ":" }, { @@ -41,7 +41,7 @@ }, { "id": 763, - "logprob": -10.1484375, + "logprob": -10.140625, "text": "like" }, { @@ -51,7 +51,7 @@ }, { "id": 322, - "logprob": -2.5683594, + "logprob": -2.5742188, "text": "and" }, { @@ -61,22 +61,22 @@ }, { "id": 1023, - "logprob": -5.0546875, + "logprob": -5.0507812, "text": "two" }, { "id": 274, - "logprob": -5.3125, + "logprob": -5.3164062, "text": "c" }, { "id": 1446, - "logprob": -0.6665039, + "logprob": -0.6694336, "text": "ats" }, { "id": 29889, - "logprob": -1.0009766, + "logprob": -0.9995117, "text": "." 
}, { @@ -89,61 +89,61 @@ "tokens": [ { "id": 6377, - "logprob": -0.15002441, + "logprob": -0.14916992, "special": false, "text": "{\"" }, { "id": 29888, - "logprob": -0.13549805, + "logprob": -0.13598633, "special": false, "text": "f" }, { "id": 12935, - "logprob": -0.017562866, + "logprob": -0.017669678, "special": false, "text": "irs" }, { "id": 29873, - "logprob": -0.0008444786, + "logprob": -0.00085639954, "special": false, "text": "t" }, { "id": 1170, - "logprob": -0.0053634644, + "logprob": -0.0054016113, "special": false, "text": "Name" }, { "id": 4710, - "logprob": -0.13537598, + "logprob": -0.13549805, "special": false, "text": "\":\"" }, { "id": 19504, - "logprob": -0.8886719, + "logprob": -0.8852539, "special": false, "text": "David" }, { "id": 3284, - "logprob": -0.16381836, + "logprob": -0.16394043, "special": false, "text": "\",\"" }, { "id": 4230, - "logprob": -0.02017212, + "logprob": -0.020492554, "special": false, "text": "last" }, { "id": 1170, - "logprob": -0.0013923645, + "logprob": -0.0013818741, "special": false, "text": "Name" }, @@ -155,37 +155,37 @@ }, { "id": 29950, - "logprob": -0.11407471, + "logprob": -0.11578369, "special": false, "text": "H" }, { "id": 14339, - "logprob": -0.0040626526, + "logprob": -0.004131317, "special": false, "text": "olt" }, { "id": 29920, - "logprob": -0.0032863617, + "logprob": -0.0033359528, "special": false, "text": "z" }, { "id": 3284, - "logprob": -0.20507812, + "logprob": -0.20471191, "special": false, "text": "\",\"" }, { "id": 29882, - "logprob": -0.0068740845, + "logprob": -0.0069274902, "special": false, "text": "h" }, { "id": 20838, - "logprob": -0.19714355, + "logprob": -0.19580078, "special": false, "text": "obb" }, @@ -197,37 +197,37 @@ }, { "id": 4710, - "logprob": -0.31860352, + "logprob": -0.32080078, "special": false, "text": "\":\"" }, { "id": 29911, - "logprob": -2.09375, + "logprob": -2.1035156, "special": false, "text": "T" }, { "id": 11003, - "logprob": -0.02053833, + "logprob": 
-0.020767212, "special": false, "text": "rees" }, { "id": 3284, - "logprob": -0.59814453, + "logprob": -0.6010742, "special": false, "text": "\",\"" }, { "id": 29876, - "logprob": -0.5732422, + "logprob": -0.57666016, "special": false, "text": "n" }, { "id": 398, - "logprob": -0.006198883, + "logprob": -0.0061073303, "special": false, "text": "um" }, @@ -245,19 +245,19 @@ }, { "id": 1115, - "logprob": -0.002117157, + "logprob": -0.0021018982, "special": false, "text": "\":" }, { "id": 29906, - "logprob": -0.089416504, + "logprob": -0.08996582, "special": false, "text": "2" }, { "id": 29913, - "logprob": -0.021835327, + "logprob": -0.021697998, "special": false, "text": "}" }, diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json index 71dad72c..1ba9ae1e 100644 --- a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json @@ -11,12 +11,12 @@ }, { "id": 806, - "logprob": -11.90625, + "logprob": -11.890625, "text": "Wh" }, { "id": 1446, - "logprob": -3.6660156, + "logprob": -3.6699219, "text": "ats" }, { @@ -26,17 +26,17 @@ }, { "id": 468, - "logprob": -8.0625, + "logprob": -8.0703125, "text": "og" }, { "id": 793, - "logprob": -2.1816406, + "logprob": -2.1875, "text": "les" }, { "id": 16332, - "logprob": -9.71875, + "logprob": -9.7109375, "text": "DNS" } ], @@ -44,13 +44,13 @@ "tokens": [ { "id": 29946, - "logprob": -1.4736328, + "logprob": -1.4765625, "special": false, "text": "4" }, { "id": 29906, - "logprob": -0.91845703, + "logprob": -0.9199219, "special": false, "text": "2" }, @@ -62,43 +62,43 @@ }, { "id": 29896, - "logprob": -1.1386719, + "logprob": -1.1367188, "special": false, "text": "1" }, { "id": 29889, - "logprob": -1.4638672, + "logprob": -1.4648438, "special": false, 
"text": "." }, { "id": 29896, - "logprob": -0.40771484, + "logprob": -0.40722656, "special": false, "text": "1" }, { "id": 29889, - "logprob": -0.17553711, + "logprob": -0.17419434, "special": false, "text": "." }, { "id": 29896, - "logprob": -0.20776367, + "logprob": -0.20251465, "special": false, "text": "1" }, { "id": 29900, - "logprob": -1.5546875, + "logprob": -1.5527344, "special": false, "text": "0" }, { "id": 29896, - "logprob": -1.3681641, + "logprob": -1.3710938, "special": false, "text": "1" } diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py index f4634fbd..3abe1077 100644 --- a/integration-tests/models/test_grammar_llama.py +++ b/integration-tests/models/test_grammar_llama.py @@ -1,10 +1,14 @@ import pytest import json +from text_generation.types import GrammarType + @pytest.fixture(scope="module") def flash_llama_grammar_handle(launcher): - with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, grammar_support=True) as handle: + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, grammar_support=True + ) as handle: yield handle @@ -33,7 +37,10 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot) max_new_tokens=10, decoder_input_details=True, seed=0, - grammar="((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", + }, ) assert response.details.generated_tokens == 10 @@ -49,31 +56,37 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot): max_new_tokens=100, decoder_input_details=True, seed=0, - grammar=json.dumps( - { - "$id": "https://example.com/person.schema.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "title": "Person", - "type": "object", - "properties": { - "firstName": { - "type": "string", - "description": "The person'''s 
first name.", + grammar={ + "type": GrammarType.Json, # "json" + "value": json.dumps( + { + "type": "object", + "$id": "https://example.com/person.schema.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Person", + "properties": { + "firstName": { + "type": "string", + "description": "The person'''s first name.", + }, + "lastName": { + "type": "string", + "description": "The person'''s last name.", + }, + "hobby": { + "description": "The person'''s hobby.", + "type": "string", + }, + "numCats": { + "description": "The number of cats the person has.", + "type": "integer", + "minimum": 0, + }, }, - "lastName": { - "type": "string", - "description": "The person'''s last name.", - }, - "hobby": {"description": "The person'''s hobby.", "type": "string"}, - "numCats": { - "description": "The number of cats the person has.", - "type": "integer", - "minimum": 0, - }, - }, - "required": ["firstName", "lastName", "hobby", "numCats"], - } - ), + "required": ["firstName", "lastName", "hobby", "numCats"], + } + ), + }, ) assert response.details.generated_tokens == 30 @@ -96,7 +109,10 @@ async def test_flash_llama_grammar_load( n=4, stop_sequences=[".com"], seed=0, - grammar="[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + }, ) assert len(responses) == 4 @@ -123,7 +139,10 @@ async def test_flash_llama_grammar_single_load_instance( max_new_tokens=10, stop_sequences=[".com"], seed=0, - grammar="[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + }, ) # assert response.details.generated_tokens == 30