Only return top_tokens field when requested

Mimics the behaviour of `best_of`. Also allows client compatibility with
older versions.
Vincent Brouwers 2023-08-02 13:03:19 +00:00 committed by Nicolas Patry
parent 8b2847fcf8
commit 66705831a9
4 changed files with 41 additions and 19 deletions
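
The change can be exercised end to end with the updated Python client. The sketch below is illustrative only (the local endpoint and prompt are assumptions, not part of this commit): when `top_n_tokens` is passed, `details.top_tokens` carries the per-step candidates; when it is omitted, the server never serializes the field at all.

    from text_generation import Client

    client = Client("http://127.0.0.1:8080")  # assumed local TGI endpoint
    response = client.generate(
        "What is Deep Learning?",
        max_new_tokens=5,
        top_n_tokens=3,  # new parameter introduced by this change
    )
    # Each generation step now carries its 3 most likely candidate tokens
    for step in response.details.top_tokens or []:
        print([(t.text, t.logprob) for t in step])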

View File

@ -75,6 +75,7 @@ class Client:
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
) -> Response:
"""
Given a prompt, generate the following text
@ -113,6 +114,8 @@ class Client:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
Response: generated response
@ -134,6 +137,7 @@ class Client:
typical_p=typical_p,
watermark=watermark,
decoder_input_details=decoder_input_details,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -164,6 +168,7 @@ class Client:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
@ -198,6 +203,8 @@ class Client:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
Iterator[StreamResponse]: stream of generated tokens
@ -219,6 +226,7 @@ class Client:
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
@ -317,6 +325,7 @@ class AsyncClient:
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
@ -355,6 +364,8 @@ class AsyncClient:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
Response: generated response
@ -376,6 +387,7 @@ class AsyncClient:
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -404,6 +416,7 @@ class AsyncClient:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
@ -438,6 +451,8 @@ class AsyncClient:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
@ -459,6 +474,7 @@ class AsyncClient:
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
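
The streaming path accepts the same flag; a minimal sketch under the same assumptions (placeholder endpoint and prompt):

    from text_generation import Client

    client = Client("http://127.0.0.1:8080")  # assumed local TGI endpoint
    for chunk in client.generate_stream(
        "What is Deep Learning?",
        max_new_tokens=5,
        top_n_tokens=3,
    ):
        # `top_tokens` is only populated when `top_n_tokens` was requested
        if chunk.top_tokens:
            print(chunk.token.text, "<-", [t.text for t in chunk.top_tokens])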

View File

@ -39,6 +39,8 @@ class Parameters(BaseModel):
details: bool = False
# Get decoder input token logprobs and ids
decoder_input_details: bool = False
# Return the N most likely tokens at each step
top_n_tokens: Optional[int]
@validator("best_of")
def valid_best_of(cls, field_value, values):
@ -101,6 +103,12 @@ class Parameters(BaseModel):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v
@validator("top_n_tokens")
def valid_top_n_tokens(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_n_tokens` must be strictly positive")
return v
class Request(BaseModel):
# Prompt
@ -125,9 +133,7 @@ class Request(BaseModel):
and parameters.best_of > 1
and field_value
):
raise ValidationError(
"`best_of` != 1 is not supported when `stream` == True"
)
raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
return field_value
@ -180,8 +186,7 @@ class BestOfSequence(BaseModel):
# Generated tokens
tokens: List[Token]
# Most likely tokens
# TODO: Make this optional?
top_tokens: List[List[Token]]
top_tokens: Optional[List[List[Token]]]
# `generate` details
@ -196,9 +201,8 @@ class Details(BaseModel):
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
# TODO: Make this optional?
top_tokens: List[List[Token]]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
@ -226,8 +230,7 @@ class StreamResponse(BaseModel):
# Generated token
token: Token
# Most likely tokens
# TODO: Make this optional?
top_tokens: List[Token]
top_tokens: Optional[List[Token]]
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]
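
Two things to note in the Pydantic models above: `top_n_tokens` must be strictly positive, and the `top_tokens` response fields are now `Optional`, so responses that omit the field (older servers, or requests that never asked for it) still parse. A rough sketch of both behaviours, assuming the models keep their other defaults:

    from text_generation.types import Parameters, StreamResponse, Token

    Parameters(top_n_tokens=3)      # accepted
    try:
        Parameters(top_n_tokens=0)  # rejected by the new validator
    except Exception as exc:
        print(exc)                  # "`top_n_tokens` must be strictly positive"

    # A stream chunk without `top_tokens` still validates
    StreamResponse(token=Token(id=0, text="a", logprob=-0.1, special=False))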

View File

@ -239,6 +239,7 @@ pub(crate) struct BestOfSequence {
pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
}
@ -254,6 +255,7 @@ pub(crate) struct Details {
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of_sequences: Option<Vec<BestOfSequence>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
}
@ -278,8 +280,8 @@ pub(crate) struct StreamDetails {
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
pub token: Token,
#[schema(nullable = true, default = "null")]
pub top_tokens: Option<Vec<Token>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Token>,
#[schema(nullable = true, default = "null", example = "test")]
pub generated_text: Option<String>,
#[schema(nullable = true, default = "null")]
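
The `#[serde(skip_serializing_if = "Vec::is_empty")]` attributes are what make the field opt-in on the wire: when `top_n_tokens` was not requested the vectors stay empty and serde drops `top_tokens` from the JSON entirely, so older clients never see an unknown key. A rough Python illustration of the same rule (the helper and payloads are made up for this example, not captured from a server):

    import json

    def dump_details(details: dict) -> str:
        # Mirror serde's skip_serializing_if: drop `top_tokens` when it is empty
        if not details.get("top_tokens"):
            details = {k: v for k, v in details.items() if k != "top_tokens"}
        return json.dumps(details)

    # top_n_tokens not requested -> field omitted, matching the pre-change schema
    print(dump_details({"generated_tokens": 2, "tokens": [], "top_tokens": []}))
    # top_n_tokens requested -> field serialized as usual
    print(dump_details({"generated_tokens": 2, "tokens": [],
                        "top_tokens": [[{"text": "a", "logprob": -0.1}]]}))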

View File

@ -198,6 +198,11 @@ async fn generate(
.collect()
});
// let top_tokens = match response.top_tokens.is_empty() {
// true => None,
// false => Some(response.top_tokens),
// };
Some(Details {
finish_reason: FinishReason::from(response.generated_text.finish_reason),
generated_tokens: response.generated_text.generated_tokens,
@ -376,12 +381,8 @@ async fn generate_stream(
tracing::error!("{err}");
yield Ok(Event::from(err));
} else {
<<<<<<< HEAD
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
=======
let top_n_tokens = req.0.parameters.top_n_tokens;
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
>>>>>>> 7c014c7 (Add WIP support for returning top tokens)
let top_n_tokens = req.parameters.top_n_tokens;
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
Ok((_permit, mut response_stream)) => {
// Server-Sent Event stream
@ -401,7 +402,7 @@ async fn generate_stream(
// StreamResponse
let stream_token = StreamResponse {
token,
top_tokens: top_n_tokens.and(Some(top_tokens)),
top_tokens,
generated_text: None,
details: None,
};
@ -463,7 +464,7 @@ async fn generate_stream(
let stream_token = StreamResponse {
token,
top_tokens:top_n_tokens.and(Some(top_tokens)),
top_tokens,
generated_text: Some(output_text),
details
};