From c5a0b65c47be15ac85b363d8d7bacef6fc65dfb2 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 9 Mar 2023 14:39:00 +0100
Subject: [PATCH] use flan-t5 for tests

---
 clients/python/README.md                   | 18 ++++++++++++
 clients/python/tests/conftest.py           |  5 ----
 clients/python/tests/test_client.py        | 32 +++++++++++-----------
 clients/python/tests/test_inference_api.py |  8 +++---
 clients/python/text_generation/types.py    | 26 +++++++++++++++---
 5 files changed, 60 insertions(+), 29 deletions(-)

diff --git a/clients/python/README.md b/clients/python/README.md
index 79b88377..f509e65c 100644
--- a/clients/python/README.md
+++ b/clients/python/README.md
@@ -133,6 +133,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"
 
 
+# Additional sequences when using the `best_of` parameter
+class BestOfSequence:
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
 # `generate` details
 class Details:
     # Generation finish reason
@@ -145,6 +161,8 @@ class Details:
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]
 
 
 # `generate` return value
diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py
index 4298623e..48734f0d 100644
--- a/clients/python/tests/conftest.py
+++ b/clients/python/tests/conftest.py
@@ -4,11 +4,6 @@ from text_generation import __version__
 from huggingface_hub.utils import build_hf_headers
 
 
-@pytest.fixture
-def bloom_model():
-    return "bigscience/bloom"
-
-
 @pytest.fixture
 def flan_t5_xxl():
     return "google/flan-t5-xxl"
diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py
index 2f96aa87..dac985bc 100644
--- a/clients/python/tests/test_client.py
+++ b/clients/python/tests/test_client.py
@@ -5,21 +5,21 @@ from text_generation.errors import NotFoundError, ValidationError
 from text_generation.types import FinishReason, PrefillToken, Token
 
 
-def test_generate(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate("test", max_new_tokens=1)
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
     assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
+        id=0, text="<pad>", logprob=None
     )
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
 
 
@@ -35,8 +35,8 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
         client.generate("test", max_new_tokens=10_000)
 
 
-def test_generate_stream(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_stream(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     responses = [
         response for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -44,7 +44,7 @@
     assert len(responses) == 1
     response = responses[0]
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
@@ -63,21 +63,21 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
 
 
 @pytest.mark.asyncio
-async def test_generate_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate("test", max_new_tokens=1)
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
     assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
+        id=0, text="<pad>", logprob=None
    )
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
 
 
@@ -96,8 +96,8 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
 
 
 @pytest.mark.asyncio
-async def test_generate_stream_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     responses = [
         response async for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -105,7 +105,7 @@
     assert len(responses) == 1
     response = responses[0]
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
diff --git a/clients/python/tests/test_inference_api.py b/clients/python/tests/test_inference_api.py
index dc744940..79e503a3 100644
--- a/clients/python/tests/test_inference_api.py
+++ b/clients/python/tests/test_inference_api.py
@@ -14,8 +14,8 @@ def test_get_supported_models():
     assert isinstance(get_supported_models(), list)
 
 
-def test_client(bloom_model):
-    client = InferenceAPIClient(bloom_model)
+def test_client(flan_t5_xxl):
+    client = InferenceAPIClient(flan_t5_xxl)
     assert isinstance(client, Client)
 
 
@@ -24,8 +24,8 @@ def test_client_unsupported_model(unsupported_model):
         InferenceAPIClient(unsupported_model)
 
 
-def test_async_client(bloom_model):
-    client = InferenceAPIAsyncClient(bloom_model)
+def test_async_client(flan_t5_xxl):
+    client = InferenceAPIAsyncClient(flan_t5_xxl)
     assert isinstance(client, AsyncClient)
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 954a0f2b..7ce5c7f6 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -105,10 +105,10 @@ def valid_best_of_stream(cls, field_value, values):
         parameters = values["parameters"]
         if (
-                parameters is not None
-                and parameters.best_of is not None
-                and parameters.best_of > 1
-                and field_value
+            parameters is not None
+            and parameters.best_of is not None
+            and parameters.best_of > 1
+            and field_value
         ):
             raise ValidationError(
                 "`best_of` != 1 is not supported when `stream` == True"
             )
@@ -150,6 +150,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"
 
 
+# Additional sequences when using the `best_of` parameter
+class BestOfSequence(BaseModel):
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
 # `generate` details
 class Details(BaseModel):
     # Generation finish reason
@@ -162,6 +178,8 @@ class Details(BaseModel):
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]
 
 
 # `generate` return value
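
A minimal usage sketch for the `best_of_sequences` field documented in this patch. The endpoint URL is a hypothetical placeholder and the prompt/parameter values are illustrative only; `best_of` requires sampling, and per the validator in `types.py` it is only accepted on non-streaming requests:

```python
from text_generation import Client

# Hypothetical local text-generation-inference endpoint; adjust to your deployment.
client = Client("http://127.0.0.1:8080")

# `best_of=2` samples two sequences server-side and returns the better one;
# sampling must be enabled (here via `do_sample`) for `best_of` > 1.
response = client.generate(
    "test",
    max_new_tokens=1,
    best_of=2,
    do_sample=True,
)

print(response.generated_text)

# Sequences that were generated but not selected are exposed on the new
# `best_of_sequences` field of `Details`; each entry mirrors the top-level
# details (finish_reason, generated_tokens, seed, prefill, tokens).
if response.details.best_of_sequences is not None:
    for sequence in response.details.best_of_sequences:
        print(sequence.finish_reason, sequence.generated_tokens, sequence.seed)
```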
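
And a sketch of the guard exercised by `valid_best_of_stream` above: constructing a streaming `Request` with `best_of` > 1 fails validation before anything reaches the server. This assumes `Parameters` accepts `best_of` and `do_sample` as keyword fields, as the hunk suggests:

```python
from text_generation.errors import ValidationError
from text_generation.types import Parameters, Request

# `best_of` != 1 is incompatible with streaming, so the `stream` validator
# shown in the hunk above raises at model construction time.
try:
    Request(
        inputs="test",
        stream=True,
        parameters=Parameters(best_of=2, do_sample=True),
    )
except ValidationError as error:
    print(error)  # `best_of` != 1 is not supported when `stream` == True
```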