From 5e1473f0f80f144a6b028a12f2e164e0f089b437 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 9 Mar 2023 13:48:58 +0100
Subject: [PATCH] feat(python-client): add new parameters

---
 clients/python/pyproject.toml              |   2 +-
 clients/python/tests/test_types.py         |  42 +++++++-
 clients/python/text_generation/__init__.py |   2 +-
 clients/python/text_generation/client.py   |  52 +++++++++
 clients/python/text_generation/types.py    | 117 +++++++++++++++++++--
 5 files changed, 200 insertions(+), 15 deletions(-)

diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml
index 0b8fa8c7..51ecce82 100644
--- a/clients/python/pyproject.toml
+++ b/clients/python/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.2.1"
+version = "0.3.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene "]
diff --git a/clients/python/tests/test_types.py b/clients/python/tests/test_types.py
index d319b570..b32f3296 100644
--- a/clients/python/tests/test_types.py
+++ b/clients/python/tests/test_types.py
@@ -1,10 +1,17 @@
 import pytest
 
-from text_generation.types import Parameters
+from text_generation.types import Parameters, Request
 from text_generation.errors import ValidationError
 
 
 def test_parameters_validation():
+    # Test best_of
+    Parameters(best_of=1)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=0)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=-1)
+
     # Test repetition_penalty
     Parameters(repetition_penalty=1)
     with pytest.raises(ValidationError):
@@ -32,8 +39,39 @@ def test_parameters_validation():
         Parameters(top_k=-1)
 
     # Test top_p
-    Parameters(top_p=1)
+    Parameters(top_p=0.5)
     with pytest.raises(ValidationError):
         Parameters(top_p=0)
     with pytest.raises(ValidationError):
         Parameters(top_p=-1)
+    with pytest.raises(ValidationError):
+        Parameters(top_p=1)
+
+    # Test truncate
+    Parameters(truncate=1)
+    with pytest.raises(ValidationError):
+        Parameters(truncate=0)
+    with pytest.raises(ValidationError):
+        Parameters(truncate=-1)
+
+    # Test typical_p
+    Parameters(typical_p=0.5)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=0)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=-1)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=1)
+
+
+def test_request_validation():
+    Request(inputs="test")
+
+    with pytest.raises(ValidationError):
+        Request(inputs="")
+
+    Request(inputs="test", stream=True)
+    Request(inputs="test", parameters=Parameters(best_of=2))
+
+    with pytest.raises(ValidationError):
+        Request(inputs="test", parameters=Parameters(best_of=2), stream=True)
diff --git a/clients/python/text_generation/__init__.py b/clients/python/text_generation/__init__.py
index db09cdfa..46109833 100644
--- a/clients/python/text_generation/__init__.py
+++ b/clients/python/text_generation/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.2.1" +__version__ = "0.3.0" from text_generation.client import Client, AsyncClient from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 23655240..e05a002e 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -56,6 +56,7 @@ class Client: prompt: str, do_sample: bool = False, max_new_tokens: int = 20, + best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -63,6 +64,8 @@ class Client: temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, watermark: bool = False, ) -> Response: """ @@ -75,6 +78,8 @@ class Client: Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens + best_of (`int`): + Generate best_of sequences and return the one if the highest token logprobs repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. @@ -91,6 +96,11 @@ class Client: top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. + truncate (`int`): + Truncate inputs tokens to the given size + typical_p (`float`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) @@ -99,6 +109,7 @@ class Client: """ # Validate parameters parameters = Parameters( + best_of=best_of, details=True, do_sample=do_sample, max_new_tokens=max_new_tokens, @@ -109,6 +120,8 @@ class Client: temperature=temperature, top_k=top_k, top_p=top_p, + truncate=truncate, + typical_p=typical_p, watermark=watermark, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -129,6 +142,7 @@ class Client: prompt: str, do_sample: bool = False, max_new_tokens: int = 20, + best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -136,6 +150,8 @@ class Client: temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, watermark: bool = False, ) -> Iterator[StreamResponse]: """ @@ -148,6 +164,8 @@ class Client: Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens + best_of (`int`): + Generate best_of sequences and return the one if the highest token logprobs repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. @@ -164,6 +182,11 @@ class Client: top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. 
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 
@@ -172,6 +195,7 @@ class Client:
         """
         # Validate parameters
         parameters = Parameters(
+            best_of=best_of,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -182,6 +206,8 @@ class Client:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
             watermark=watermark,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
@@ -261,6 +287,7 @@ class AsyncClient:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -268,6 +295,8 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
         watermark: bool = False,
     ) -> Response:
         """
@@ -280,6 +309,8 @@ class AsyncClient:
                 Activate logits sampling
             max_new_tokens (`int`):
                 Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty.
                 See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -296,6 +327,11 @@ class AsyncClient:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 
@@ -304,6 +340,7 @@ class AsyncClient:
         """
         # Validate parameters
         parameters = Parameters(
+            best_of=best_of,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -314,6 +351,8 @@ class AsyncClient:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
             watermark=watermark,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -331,6 +370,7 @@ class AsyncClient:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -338,6 +378,8 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
         watermark: bool = False,
     ) -> AsyncIterator[StreamResponse]:
         """
@@ -350,6 +392,8 @@ class AsyncClient:
                 Activate logits sampling
             max_new_tokens (`int`):
                 Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty.
                 See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -366,6 +410,11 @@ class AsyncClient:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 
@@ -374,6 +423,7 @@ class AsyncClient:
         """
         # Validate parameters
         parameters = Parameters(
+            best_of=best_of,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -384,6 +434,8 @@ class AsyncClient:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
             watermark=watermark,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index d276b60e..954a0f2b 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -6,27 +6,53 @@ from text_generation.errors import ValidationError
 
 
 class Parameters(BaseModel):
+    # Generate best_of sequences and return the one with the highest token logprobs
+    best_of: Optional[int]
+    # Activate logits sampling
     do_sample: bool = False
+    # Maximum number of generated tokens
     max_new_tokens: int = 20
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float] = None
+    # Whether to prepend the prompt to the generated text
     return_full_text: bool = False
+    # Stop generating tokens if a member of `stop_sequences` is generated
     stop: List[str] = []
+    # Random sampling seed
     seed: Optional[int]
+    # The value used to modulate the logits distribution.
     temperature: Optional[float]
+    # The number of highest probability vocabulary tokens to keep for top-k-filtering.
     top_k: Optional[int]
+    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+    # higher are kept for generation.
     top_p: Optional[float]
+    # Truncate input tokens to the given size
+    truncate: Optional[int]
+    # Typical Decoding mass
+    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
+    typical_p: Optional[float]
+    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
     watermark: bool = False
+    # Get generation details
     details: bool = False
 
+    @validator("best_of")
+    def valid_best_of(cls, v):
+        if v is not None and v <= 0:
+            raise ValidationError("`best_of` must be strictly positive")
+        return v
+
     @validator("repetition_penalty")
     def valid_repetition_penalty(cls, v):
-        if v is not None and v is v <= 0:
+        if v is not None and v <= 0:
             raise ValidationError("`repetition_penalty` must be strictly positive")
         return v
 
     @validator("seed")
     def valid_seed(cls, v):
-        if v is not None and v is v < 0:
+        if v is not None and v < 0:
             raise ValidationError("`seed` must be positive")
         return v
 
@@ -44,56 +70,125 @@ class Parameters(BaseModel):
 
     @validator("top_p")
     def valid_top_p(cls, v):
-        if v is not None and (v <= 0 or v > 1.0):
-            raise ValidationError("`top_p` must be > 0.0 and <= 1.0")
+        if v is not None and (v <= 0 or v >= 1.0):
+            raise ValidationError("`top_p` must be > 0.0 and < 1.0")
+        return v
+
+    @validator("truncate")
+    def valid_truncate(cls, v):
+        if v is not None and v <= 0:
+            raise ValidationError("`truncate` must be strictly positive")
+        return v
+
+    @validator("typical_p")
+    def valid_typical_p(cls, v):
+        if v is not None and (v <= 0 or v >= 1.0):
+            raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
         return v
 
 
 class Request(BaseModel):
+    # Prompt
     inputs: str
-    parameters: Parameters
+    # Generation parameters
+    parameters: Optional[Parameters]
+    # Whether to stream output tokens
     stream: bool = False
 
+    @validator("inputs")
+    def valid_input(cls, v):
+        if not v:
+            raise ValidationError("`inputs` cannot be empty")
+        return v
+
+    @validator("stream")
+    def valid_best_of_stream(cls, field_value, values):
+        parameters = values["parameters"]
+        if (
+            parameters is not None
+            and parameters.best_of is not None
+            and parameters.best_of > 1
+            and field_value
+        ):
+            raise ValidationError(
+                "`best_of` != 1 is not supported when `stream` == True"
+            )
+        return field_value
+
 
+# Prompt tokens
 class PrefillToken(BaseModel):
+    # Token ID from the model tokenizer
     id: int
+    # Token text
     text: str
+    # Logprob
+    # Optional since the logprob of the first token cannot be computed
     logprob: Optional[float]
 
 
+# Generated tokens
 class Token(BaseModel):
+    # Token ID from the model tokenizer
     id: int
+    # Token text
     text: str
+    # Logprob
     logprob: float
+    # Is the token a special token
+    # Can be used to ignore tokens when concatenating
     special: bool
 
 
+# Generation finish reason
 class FinishReason(Enum):
+    # number of generated tokens == `max_new_tokens`
     Length = "length"
+    # the model generated its end of sequence token
     EndOfSequenceToken = "eos_token"
+    # the model generated a text included in `stop_sequences`
     StopSequence = "stop_sequence"
 
 
+# `generate` details
 class Details(BaseModel):
+    # Generation finish reason
     finish_reason: FinishReason
+    # Number of generated tokens
     generated_tokens: int
+    # Sampling seed if sampling was activated
     seed: Optional[int]
+    # Prompt tokens
     prefill: List[PrefillToken]
+    # Generated tokens
     tokens: List[Token]
 
 
-class StreamDetails(BaseModel):
-    finish_reason: FinishReason
-    generated_tokens: int
-    seed: Optional[int]
-
-
+# `generate` return value
 class Response(BaseModel):
+    # Generated text
     generated_text: str
+    # Generation details
     details: Details
 
 
+# `generate_stream` details
+class StreamDetails(BaseModel):
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+
+
+# `generate_stream` return value
 class StreamResponse(BaseModel):
+    # Generated token
     token: Token
+    # Complete generated text
+    # Only available when the generation is finished
     generated_text: Optional[str]
+    # Generation details
+    # Only available when the generation is finished
     details: Optional[StreamDetails]
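
Below is a minimal usage sketch of the parameters introduced by this patch. It is not part of the change itself: the server URL, prompt, and parameter values are placeholders chosen for illustration.

    from text_generation import Client

    # Placeholder URL for a locally running text-generation-inference server.
    client = Client("http://127.0.0.1:8080")

    # best_of, truncate and typical_p are the parameters added in this patch.
    # do_sample is enabled so the best_of candidates can differ from one another.
    response = client.generate(
        "What is Deep Learning?",
        max_new_tokens=20,
        do_sample=True,
        best_of=2,
        truncate=512,
        typical_p=0.9,
    )
    print(response.generated_text)

    # As covered by test_request_validation, best_of > 1 combined with streaming
    # is rejected client-side: building a Request with parameters=Parameters(best_of=2)
    # and stream=True raises a ValidationError.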