diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py
index 775e7a6c..1e25e1b1 100644
--- a/clients/python/tests/test_client.py
+++ b/clients/python/tests/test_client.py
@@ -21,22 +21,6 @@ def test_generate(flan_t5_xxl_url, hf_headers):
     assert not response.details.tokens[0].special
 
 
-def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
-    response = client.generate("test", decoder_input_details=True)
-
-    assert response.generated_text != ""
-    assert response.details.finish_reason == FinishReason.EndOfSequenceToken
-    assert response.details.generated_tokens > 1
-    assert response.details.seed is None
-    assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None)
-    assert len(response.details.tokens) > 1
-    assert response.details.tokens[0].id == 3
-    assert response.details.tokens[0].text == " "
-    assert not response.details.tokens[0].special
-
-
 def test_generate_best_of(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate(
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 63b5258d..0bf80f8c 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -62,7 +62,7 @@ class Client:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -157,7 +157,7 @@ class Client:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -312,7 +312,7 @@ class AsyncClient:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -405,7 +405,7 @@ class AsyncClient:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 7fa8033e..aa02d8d8 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -9,7 +9,7 @@ class Parameters(BaseModel):
     # Activate logits sampling
     do_sample: bool = False
     # Maximum number of generated tokens
-    max_new_tokens: Optional[int] = None
+    max_new_tokens: int = 20
    # The parameter for repetition penalty. 1.0 means no penalty.
     # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float] = None
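
With this change, a call to `generate` that omits `max_new_tokens` now sends an explicit `max_new_tokens=20` in `Parameters` instead of leaving the field unset for the server to fill in. Below is a minimal usage sketch (not part of the diff) of the resulting behaviour, assuming a running text-generation-inference endpoint; the URL is a placeholder.

```python
# Minimal sketch: the endpoint URL below is a placeholder, not part of this diff.
from text_generation import Client

client = Client("http://127.0.0.1:8080")

# No max_new_tokens passed: the client now defaults to 20 generated tokens
# rather than omitting the field and deferring to the server-side default.
response = client.generate("What is Deep Learning?")
print(response.generated_text)

# Equivalent explicit call with the new default value.
response = client.generate("What is Deep Learning?", max_new_tokens=20)
print(response.details.generated_tokens)  # at most 20 (may stop earlier at EOS)
```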