From 1990d8633ca866cb9a84dbfb5b64fb669582b60d Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 9 Mar 2023 15:04:59 +0100
Subject: [PATCH] add validation

---
 clients/python/tests/test_client.py     |  8 ++------
 clients/python/tests/test_types.py      |  9 +++++++--
 clients/python/text_generation/types.py | 31 +++++++++++++++++++----------
 3 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py
index dac985bc..76ac80d3 100644
--- a/clients/python/tests/test_client.py
+++ b/clients/python/tests/test_client.py
@@ -14,9 +14,7 @@ def test_generate(flan_t5_xxl_url, hf_headers):
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(
-        id=0, text="", logprob=None
-    )
+    assert response.details.prefill[0] == PrefillToken(id=0, text="", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
         id=3, text=" ", logprob=-1.984375, special=False
@@ -72,9 +70,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(
-        id=0, text="", logprob=None
-    )
+    assert response.details.prefill[0] == PrefillToken(id=0, text="", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
         id=3, text=" ", logprob=-1.984375, special=False
diff --git a/clients/python/tests/test_types.py b/clients/python/tests/test_types.py
index b32f3296..4c9d4c89 100644
--- a/clients/python/tests/test_types.py
+++ b/clients/python/tests/test_types.py
@@ -11,6 +11,9 @@ def test_parameters_validation():
         Parameters(best_of=0)
     with pytest.raises(ValidationError):
         Parameters(best_of=-1)
+    Parameters(best_of=2, do_sample=True)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=2)
 
     # Test repetition_penalty
     Parameters(repetition_penalty=1)
@@ -71,7 +74,9 @@ def test_request_validation():
         Request(inputs="")
 
     Request(inputs="test", stream=True)
-    Request(inputs="test", parameters=Parameters(best_of=2))
+    Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))
 
     with pytest.raises(ValidationError):
-        Request(inputs="test", parameters=Parameters(best_of=2), stream=True)
+        Request(
+            inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
+        )
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 7ce5c7f6..ea2070b8 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -6,8 +6,6 @@ from text_generation.errors import ValidationError
 
 
 class Parameters(BaseModel):
-    # Generate best_of sequences and return the one if the highest token logprobs
-    best_of: Optional[int]
     # Activate logits sampling
     do_sample: bool = False
     # Maximum number of generated tokens
@@ -33,16 +31,29 @@ class Parameters(BaseModel):
     # Typical Decoding mass
     # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
     typical_p: Optional[float]
+    # Generate best_of sequences and return the one if the highest token logprobs
+    best_of: Optional[int]
     # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
     watermark: bool = False
     # Get generation details
     details: bool = False
 
     @validator("best_of")
-    def valid_best_of(cls, v):
-        if v is not None and v <= 0:
-            raise ValidationError("`best_of` must be strictly positive")
-        return v
+    def valid_best_of(cls, field_value, values):
+        if field_value is not None:
+            if field_value <= 0:
+                raise ValidationError("`best_of` must be strictly positive")
+            sampling = (
+                values["do_sample"]
+                | (values["temperature"] is not None)
+                | (values["top_k"] is not None)
+                | (values["top_p"] is not None)
+                | (values["typical_p"] is not None)
+            )
+            if field_value > 1 and not sampling:
+                raise ValidationError("you must use sampling when `best_of` is > 1")
+
+        return field_value
 
     @validator("repetition_penalty")
     def valid_repetition_penalty(cls, v):
@@ -105,10 +116,10 @@ class Request(BaseModel):
     def valid_best_of_stream(cls, field_value, values):
         parameters = values["parameters"]
         if (
-                parameters is not None
-                and parameters.best_of is not None
-                and parameters.best_of > 1
-                and field_value
+            parameters is not None
+            and parameters.best_of is not None
+            and parameters.best_of > 1
+            and field_value
         ):
             raise ValidationError(
                 "`best_of` != 1 is not supported when `stream` == True"
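
Note (not part of the patch): a minimal usage sketch of the validation this change introduces, assuming the patched text_generation client is importable. The exact exception type that reaches the caller depends on how ValidationError is defined in text_generation.errors, so the sketch catches broadly rather than asserting a specific class.

from text_generation.types import Parameters

# Passes: sampling is explicitly enabled alongside best_of > 1.
Parameters(best_of=2, do_sample=True)

# Should also pass: per the `sampling` expression in the new validator, setting
# any of temperature, top_k, top_p, or typical_p counts as sampling.
Parameters(best_of=2, temperature=0.7)

# Rejected: best_of > 1 without any sampling parameter raises a validation
# error ("you must use sampling when `best_of` is > 1").
try:
    Parameters(best_of=2)
except Exception as err:
    print(type(err).__name__, err)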