Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-10 20:04:52 +00:00
Only return top_tokens field when requested

Mimics the behaviour of `best_of`. This also keeps the client compatible with older versions.
parent 8b2847fcf8
commit 66705831a9
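For context, the commit adds a `top_n_tokens` parameter to the Python client and makes the matching `top_tokens` response fields optional. A minimal usage sketch (assuming the `text_generation` Python client from this repo and a hypothetical local endpoint; not part of the diff):

```python
# Minimal sketch: requesting top tokens vs. leaving the feature off.
# Assumes a TGI server at a hypothetical local address.
from text_generation import Client

client = Client("http://127.0.0.1:8080")

# Opt in: return the 3 most likely tokens at each generation step.
resp = client.generate("What is Deep Learning?", max_new_tokens=5, top_n_tokens=3)
for step in resp.details.top_tokens:
    # one List[Token] per generated token
    print([(t.text, t.logprob) for t in step])

# Leave it off: the server omits the field entirely, and the client sees None,
# exactly as with responses from older versions.
resp = client.generate("What is Deep Learning?", max_new_tokens=5)
print(resp.details.top_tokens)  # None
```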
@@ -75,6 +75,7 @@ class Client:
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = False,
        top_n_tokens: Optional[int] = None,
    ) -> Response:
        """
        Given a prompt, generate the following text

@@ -113,6 +114,8 @@ class Client:
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            decoder_input_details (`bool`):
                Return the decoder input token logprobs and ids
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            Response: generated response

@@ -134,6 +137,7 @@ class Client:
            typical_p=typical_p,
            watermark=watermark,
            decoder_input_details=decoder_input_details,
            top_n_tokens=top_n_tokens
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -164,6 +168,7 @@ class Client:
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        top_n_tokens: Optional[int] = None,
    ) -> Iterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens

@@ -198,6 +203,8 @@ class Client:
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            Iterator[StreamResponse]: stream of generated tokens

@@ -219,6 +226,7 @@ class Client:
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

@@ -317,6 +325,7 @@ class AsyncClient:
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = False,
        top_n_tokens: Optional[int] = None,
    ) -> Response:
        """
        Given a prompt, generate the following text asynchronously

@@ -355,6 +364,8 @@ class AsyncClient:
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            decoder_input_details (`bool`):
                Return the decoder input token logprobs and ids
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            Response: generated response

@@ -376,6 +387,7 @@ class AsyncClient:
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -404,6 +416,7 @@ class AsyncClient:
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        top_n_tokens: Optional[int] = None,
    ) -> AsyncIterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens asynchronously

@@ -438,6 +451,8 @@ class AsyncClient:
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            AsyncIterator[StreamResponse]: stream of generated tokens

@@ -459,6 +474,7 @@ class AsyncClient:
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)
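The streaming methods above gain the same parameter; a corresponding sketch (same assumptions as the previous example), where each `StreamResponse` only carries `top_tokens` when they were requested:

```python
# Streaming sketch, same hypothetical local endpoint as above.
from text_generation import Client

client = Client("http://127.0.0.1:8080")

for resp in client.generate_stream("What is Deep Learning?", max_new_tokens=5, top_n_tokens=3):
    alternatives = resp.top_tokens or []  # None when top_n_tokens was not requested
    print(resp.token.text, [t.text for t in alternatives])
```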
@@ -39,6 +39,8 @@ class Parameters(BaseModel):
    details: bool = False
    # Get decoder input token logprobs and ids
    decoder_input_details: bool = False
    # Return the N most likely tokens at each step
    top_n_tokens: Optional[int]

    @validator("best_of")
    def valid_best_of(cls, field_value, values):

@@ -101,6 +103,12 @@ class Parameters(BaseModel):
            raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
        return v

    @validator("top_n_tokens")
    def valid_top_n_tokens(cls, v):
        if v is not None and v <= 0:
            raise ValidationError("`top_n_tokens` must be strictly positive")
        return v


class Request(BaseModel):
    # Prompt

@@ -125,9 +133,7 @@
            and parameters.best_of > 1
            and field_value
        ):
            raise ValidationError(
                "`best_of` != 1 is not supported when `stream` == True"
            )
            raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
        return field_value


@@ -180,8 +186,7 @@ class BestOfSequence(BaseModel):
    # Generated tokens
    tokens: List[Token]
    # Most likely tokens
    # TODO: Make this optional?
    top_tokens: List[List[Token]]
    top_tokens: Optional[List[List[Token]]]


# `generate` details

@@ -196,9 +201,8 @@ class Details(BaseModel):
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]
    # Most likely tokens
    # TODO: Make this optional?
    top_tokens: List[List[Token]]
    # Most likely tokens
    top_tokens: Optional[List[List[Token]]]
    # Additional sequences when using the `best_of` parameter
    best_of_sequences: Optional[List[BestOfSequence]]

@@ -226,8 +230,7 @@ class StreamResponse(BaseModel):
    # Generated token
    token: Token
    # Most likely tokens
    # TODO: Make this optional?
    top_tokens: List[Token]
    top_tokens: Optional[List[Token]]
    # Complete generated text
    # Only available when the generation is finished
    generated_text: Optional[str]
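Making these pydantic fields `Optional` is what keeps older payloads valid on the client side: when the router omits `top_tokens` (see the `skip_serializing_if` attributes in the Rust changes below), the field simply deserializes to `None`. A quick sketch with a hand-written payload (values are illustrative only):

```python
# Sketch: a StreamResponse payload without top_tokens, as an older server or a
# request without top_n_tokens would produce it, still validates; the Optional
# field defaults to None. Payload values are made up for illustration.
from text_generation.types import StreamResponse

payload = {
    "token": {"id": 42, "text": " deep", "logprob": -0.3, "special": False},
    "generated_text": None,
    "details": None,
}
resp = StreamResponse(**payload)
print(resp.top_tokens)  # None
```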
@@ -239,6 +239,7 @@ pub(crate) struct BestOfSequence {
    pub seed: Option<u64>,
    pub prefill: Vec<PrefillToken>,
    pub tokens: Vec<Token>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub top_tokens: Vec<Vec<Token>>,
}

@@ -254,6 +255,7 @@ pub(crate) struct Details {
    pub tokens: Vec<Token>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of_sequences: Option<Vec<BestOfSequence>>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub top_tokens: Vec<Vec<Token>>,
}

@@ -278,8 +280,8 @@ pub(crate) struct StreamDetails {
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
    pub token: Token,
    #[schema(nullable = true, default = "null")]
    pub top_tokens: Option<Vec<Token>>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub top_tokens: Vec<Token>,
    #[schema(nullable = true, default = "null", example = "test")]
    pub generated_text: Option<String>,
    #[schema(nullable = true, default = "null")]
@@ -198,6 +198,11 @@ async fn generate(
                        .collect()
                });

                // let top_tokens = match response.top_tokens.is_empty() {
                //     true => None,
                //     false => Some(response.top_tokens),
                // };

                Some(Details {
                    finish_reason: FinishReason::from(response.generated_text.finish_reason),
                    generated_tokens: response.generated_text.generated_tokens,

@@ -376,12 +381,8 @@ async fn generate_stream(
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        } else {
<<<<<<< HEAD
            match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
=======
            let top_n_tokens = req.0.parameters.top_n_tokens;
            match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
>>>>>>> 7c014c7 (Add WIP support for returning top tokens)
            let top_n_tokens = req.parameters.top_n_tokens;
            // Keep permit as long as generate_stream lives
            Ok((_permit, mut response_stream)) => {
                // Server-Sent Event stream

@@ -401,7 +402,7 @@ async fn generate_stream(
                        // StreamResponse
                        let stream_token = StreamResponse {
                            token,
                            top_tokens: top_n_tokens.and(Some(top_tokens)),
                            top_tokens: top_tokens,
                            generated_text: None,
                            details: None,
                        };

@@ -463,7 +464,7 @@ async fn generate_stream(

                        let stream_token = StreamResponse {
                            token,
                            top_tokens:top_n_tokens.and(Some(top_tokens)),
                            top_tokens: top_tokens,
                            generated_text: Some(output_text),
                            details
                        };