Expose ignore_eos_token to client

Signed-off-by: caiyesd <caiyesd@gmail.com>
caiyesd 2023-08-12 16:14:05 +00:00
parent c4422e5678
commit 608c5c93b2
4 changed files with 24 additions and 1 deletion

View File

@@ -75,6 +75,7 @@ class Client:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        ignore_eos_token: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -113,6 +114,8 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.

         Returns:
             Response: generated response
@@ -134,6 +137,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             decoder_input_details=decoder_input_details,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -164,6 +168,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        ignore_eos_token: bool = False,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -198,6 +203,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.

         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -219,6 +226,7 @@ class Client:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
@@ -317,6 +325,7 @@ class AsyncClient:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        ignore_eos_token: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -355,6 +364,8 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.

         Returns:
             Response: generated response
@@ -376,6 +387,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -404,6 +416,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        ignore_eos_token: bool = False,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -438,6 +451,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.

         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -459,6 +474,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
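
For quick reference, a minimal usage sketch of the new flag with the synchronous client; the endpoint URL and sampling values below are placeholders, not part of this change:

    from text_generation import Client

    # Hypothetical local endpoint; point this at your own text-generation-inference server.
    client = Client("http://127.0.0.1:8080")

    # With ignore_eos_token=True the server keeps generating past the EOS token,
    # so generation only stops on max_new_tokens or an explicit stop sequence.
    response = client.generate(
        "Once upon a time",
        max_new_tokens=64,
        ignore_eos_token=True,
    )
    print(response.generated_text)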

View File

@@ -39,6 +39,8 @@ class Parameters(BaseModel):
     details: bool = False
     # Get decoder input token logprobs and ids
     decoder_input_details: bool = False
+    # Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
+    ignore_eos_token: bool = False

     @validator("best_of")
     def valid_best_of(cls, field_value, values):
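
At the types level the field is just another optional knob on the Pydantic model; a short sketch (module path assumed from the client package layout, other values illustrative):

    from text_generation.types import Parameters, Request

    # ignore_eos_token defaults to False, so existing callers are unaffected.
    parameters = Parameters(max_new_tokens=32, ignore_eos_token=True)
    request = Request(inputs="Hello", stream=False, parameters=parameters)
    print(request.json())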

View File

@@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters {
         example = "null"
     )]
     pub seed: Option<u64>,
+    #[serde(default)]
+    #[schema(default = "false")]
+    pub ignore_eos_token: bool,
 }

 fn default_max_new_tokens() -> u32 {
@@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters {
         details: false,
         decoder_input_details: false,
         seed: None,
+        ignore_eos_token: false,
     }
 }
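
Because the new field is #[serde(default)], requests that omit it keep working. A raw HTTP sketch against the router's /generate route (host, prompt, and token budget are placeholders):

    import requests

    # ignore_eos_token rides along inside the "parameters" object of the request body.
    resp = requests.post(
        "http://127.0.0.1:8080/generate",
        json={
            "inputs": "Once upon a time",
            "parameters": {"max_new_tokens": 64, "ignore_eos_token": True},
        },
    )
    print(resp.json())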

View File

@@ -142,6 +142,7 @@ impl Validation {
             seed,
             watermark,
             decoder_input_details,
+            ignore_eos_token,
             ..
         } = request.parameters;
@@ -251,7 +252,7 @@ impl Validation {
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
             stop_sequences,
-            ignore_eos_token: false,
+            ignore_eos_token,
         };

         metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
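
With this change the request value flows into StoppingCriteriaParameters instead of being hard-coded to false. As a rough illustration of the downstream effect (not the server's actual stopping code), a stopping check typically honors the flag like this:

    class EosAwareStopper:
        """Illustrative only: stop on EOS unless ignore_eos_token is set."""

        def __init__(self, eos_token_id: int, max_new_tokens: int, ignore_eos_token: bool = False):
            self.eos_token_id = eos_token_id
            self.max_new_tokens = max_new_tokens
            self.ignore_eos_token = ignore_eos_token
            self.generated = 0

        def __call__(self, last_token_id: int) -> bool:
            self.generated += 1
            if self.generated >= self.max_new_tokens:
                return True  # length budget exhausted
            if not self.ignore_eos_token and last_token_id == self.eos_token_id:
                return True  # EOS ends generation unless explicitly ignored
            return False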