diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index bf045d47..11ead1f8 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -75,6 +75,7 @@ class Client:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        ignore_eos_token: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -113,6 +114,8 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
 
         Returns:
             Response: generated response
@@ -134,6 +137,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             decoder_input_details=decoder_input_details,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -164,6 +168,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        ignore_eos_token: bool = False,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -198,6 +203,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
 
         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -219,6 +226,7 @@ class Client:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
@@ -317,6 +325,7 @@ class AsyncClient:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        ignore_eos_token: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -355,6 +364,8 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
 
         Returns:
             Response: generated response
@@ -376,6 +387,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -404,6 +416,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        ignore_eos_token: bool = False,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -438,6 +451,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
 
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -459,6 +474,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 548f0b63..7bcb6740 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -39,6 +39,8 @@ class Parameters(BaseModel):
     details: bool = False
     # Get decoder input token logprobs and ids
     decoder_input_details: bool = False
+    # Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
+    ignore_eos_token: bool = False
 
     @validator("best_of")
     def valid_best_of(cls, field_value, values):
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7dff7a11..22f7f787 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters {
         example = "null"
     )]
     pub seed: Option<u64>,
+    #[serde(default)]
+    #[schema(default = "false")]
+    pub ignore_eos_token: bool,
 }
 
 fn default_max_new_tokens() -> u32 {
@@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters {
         details: false,
         decoder_input_details: false,
         seed: None,
+        ignore_eos_token: false,
     }
 }
diff --git a/router/src/validation.rs b/router/src/validation.rs
index f967361f..04c0cfde 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -142,6 +142,7 @@ impl Validation {
             seed,
             watermark,
             decoder_input_details,
+            ignore_eos_token,
             ..
         } = request.parameters;
@@ -251,7 +252,7 @@ impl Validation {
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
             stop_sequences,
-            ignore_eos_token: false,
+            ignore_eos_token,
         };
 
         metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
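
For context, a minimal usage sketch of the flag this patch adds, assuming a text-generation-inference server running this change; the endpoint URL and prompt are placeholder values:

from text_generation import Client

# Placeholder endpoint; point this at your own TGI deployment.
client = Client("http://127.0.0.1:8080")

# With ignore_eos_token=True, the server keeps sampling past the model's
# EOS token, so generation runs until max_new_tokens is exhausted.
response = client.generate(
    "What is Deep Learning?",
    max_new_tokens=64,
    ignore_eos_token=True,
)
print(response.generated_text)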
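
On the wire, the router-side change means the same flag can be passed inside `parameters` on the existing `/generate` route; a sketch of the equivalent raw request (URL is a placeholder):

import requests

# ignore_eos_token may be omitted and defaults to false, via the
# #[serde(default)] attribute added in router/src/lib.rs.
resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 64, "ignore_eos_token": True},
    },
)
print(resp.json()["generated_text"])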