Expose ignore_eos_token to client
Signed-off-by: caiyesd <caiyesd@gmail.com>
commit 608c5c93b2 (parent c4422e5678)
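This commit threads a new `ignore_eos_token` generation parameter from the Python client (`Client` and `AsyncClient`), through the `Parameters` pydantic model, into the router's `GenerateParameters`, and finally into the stopping criteria, which previously hardcoded `ignore_eos_token: false`.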
Python client (`Client` / `AsyncClient`):

```diff
@@ -75,6 +75,7 @@ class Client:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        ignore_eos_token: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -113,6 +114,8 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
 
         Returns:
             Response: generated response
@@ -134,6 +137,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             decoder_input_details=decoder_input_details,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
 
@@ -164,6 +168,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        ignore_eos_token: bool = False,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -198,6 +203,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
 
         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -219,6 +226,7 @@ class Client:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
 
@@ -317,6 +325,7 @@ class AsyncClient:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        ignore_eos_token: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -355,6 +364,8 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
 
         Returns:
             Response: generated response
@@ -376,6 +387,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
 
@@ -404,6 +416,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        ignore_eos_token: bool = False,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -438,6 +451,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            ignore_eos_token (`bool`):
+                Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
 
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -459,6 +474,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            ignore_eos_token=ignore_eos_token,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
 
```
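For context, a minimal usage sketch of the new client flag; the server URL, prompt, and token budget below are illustrative assumptions, not part of this commit:

```python
# Minimal sketch, assuming a TGI server is already running at this URL.
from text_generation import Client

client = Client("http://127.0.0.1:8080")

# With ignore_eos_token=True the server keeps decoding past EOS, so the
# response uses the full max_new_tokens budget.
response = client.generate(
    "Once upon a time,",
    max_new_tokens=64,
    ignore_eos_token=True,
)
print(response.generated_text)

# The same flag is exposed on the streaming path.
for chunk in client.generate_stream(
    "Once upon a time,", max_new_tokens=64, ignore_eos_token=True
):
    if not chunk.token.special:
        print(chunk.token.text, end="")
```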
Python types (`Parameters` model):

```diff
@@ -39,6 +39,8 @@ class Parameters(BaseModel):
     details: bool = False
     # Get decoder input token logprobs and ids
     decoder_input_details: bool = False
+    # Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
+    ignore_eos_token: bool = False
 
     @validator("best_of")
     def valid_best_of(cls, field_value, values):
```
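A quick sketch of how the field travels over the wire, built from the models in this diff (the `.json()` call assumes the pydantic v1 API the client uses elsewhere):

```python
# Sketch: the new field serializes alongside the existing parameters.
from text_generation.types import Parameters, Request

request = Request(
    inputs="Once upon a time,",
    stream=False,
    parameters=Parameters(max_new_tokens=64, ignore_eos_token=True),
)
# The serialized payload now contains "ignore_eos_token": true.
print(request.json())
```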
Router (`GenerateParameters` and its defaults):

```diff
@@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters {
         example = "null"
     )]
     pub seed: Option<u64>,
+    #[serde(default)]
+    #[schema(default = "false")]
+    pub ignore_eos_token: bool,
 }
 
 fn default_max_new_tokens() -> u32 {
@@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters {
         details: false,
         decoder_input_details: false,
         seed: None,
+        ignore_eos_token: false,
     }
 }
 
```
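Since `GenerateParameters` now carries the field with a serde default of `false`, the flag can also be sent straight to the router's HTTP API; a sketch using the `/generate` route (the URL and port are deployment assumptions):

```python
# Sketch: calling the router's REST endpoint directly, without the client library.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "Once upon a time,",
        # Omitting ignore_eos_token keeps the old behavior (serde default: false).
        "parameters": {"max_new_tokens": 64, "ignore_eos_token": True},
    },
)
print(resp.json()["generated_text"])
```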
Router validation (`impl Validation`):

```diff
@@ -142,6 +142,7 @@ impl Validation {
             seed,
             watermark,
             decoder_input_details,
+            ignore_eos_token,
             ..
         } = request.parameters;
 
@@ -251,7 +252,7 @@ impl Validation {
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
             stop_sequences,
-            ignore_eos_token: false,
+            ignore_eos_token,
         };
 
         metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
```
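Because validation now forwards the client's value into `StoppingCriteriaParameters` instead of hardcoding `false`, a request with the flag set decodes for the full token budget. One place this is handy is length-controlled benchmarking, sketched below (the timing logic is illustrative, not part of this commit):

```python
# Hypothetical benchmark sketch: ignore_eos_token=True makes the request
# generate the full max_new_tokens budget, so latency is length-controlled.
import time

from text_generation import Client

client = Client("http://127.0.0.1:8080")
start = time.perf_counter()
response = client.generate(
    "Benchmark prompt", max_new_tokens=128, ignore_eos_token=True
)
elapsed = time.perf_counter() - start
print(f"{response.details.generated_tokens} tokens in {elapsed:.2f}s")
```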