diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index bf045d47..015613c2 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -75,6 +75,7 @@ class Client:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -113,6 +114,8 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step

         Returns:
             Response: generated response
@@ -134,6 +137,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             decoder_input_details=decoder_input_details,
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -164,6 +168,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -198,6 +203,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step

         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -219,6 +226,7 @@ class Client:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)

@@ -317,6 +325,7 @@ class AsyncClient:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -355,6 +364,8 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step

         Returns:
             Response: generated response
@@ -376,6 +387,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -404,6 +416,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -438,6 +451,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step

         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -459,6 +474,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
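For context, a minimal sketch of how the new parameter is meant to be used from the synchronous client. The endpoint URL and prompt are placeholders, and the `details.top_tokens` access assumes the server populates the field when `top_n_tokens` is set:

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

# Ask for the 3 most likely tokens at every decoding step.
response = client.generate(
    "What is Deep Learning?", max_new_tokens=8, top_n_tokens=3
)

# top_tokens is a list of steps; each step holds the n best candidates.
for step in response.details.top_tokens or []:
    print([(token.text, token.logprob) for token in step])
```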
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 52fdaf53..38f75253 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -39,6 +39,8 @@ class Parameters(BaseModel):
     details: bool = False
     # Get decoder input token logprobs and ids
     decoder_input_details: bool = False
+    # Return the N most likely tokens at each step
+    top_n_tokens: Optional[int]

     @validator("best_of")
     def valid_best_of(cls, field_value, values):
@@ -101,6 +103,12 @@ class Parameters(BaseModel):
             raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
         return v

+    @validator("top_n_tokens")
+    def valid_top_n_tokens(cls, v):
+        if v is not None and v <= 0:
+            raise ValidationError("`top_n_tokens` must be strictly positive")
+        return v
+

 class Request(BaseModel):
     # Prompt
@@ -125,9 +133,7 @@ class Request(BaseModel):
             and parameters.best_of > 1
             and field_value
         ):
-            raise ValidationError(
-                "`best_of` != 1 is not supported when `stream` == True"
-            )
+            raise ValidationError("`best_of` != 1 is not supported when `stream` == True")

         return field_value

@@ -180,8 +186,7 @@ class BestOfSequence(BaseModel):
     # Generated tokens
     tokens: List[Token]
     # Most likely tokens
-    # TODO: Make this optional?
-    top_tokens: List[List[Token]]
+    top_tokens: Optional[List[List[Token]]]


 # `generate` details
@@ -196,9 +201,8 @@ class Details(BaseModel):
     prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
-    # Most likely tokens
-    # TODO: Make this optional?
-    top_tokens: List[List[Token]]
+    # Most likely tokens
+    top_tokens: Optional[List[List[Token]]]
     # Additional sequences when using the `best_of` parameter
     best_of_sequences: Optional[List[BestOfSequence]]

@@ -226,8 +230,7 @@ class StreamResponse(BaseModel):
     # Generated token
     token: Token
     # Most likely tokens
-    # TODO: Make this optional?
-    top_tokens: List[Token]
+    top_tokens: Optional[List[Token]]
     # Complete generated text
     # Only available when the generation is finished
     generated_text: Optional[str]
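Since `top_tokens` is now `Optional`, responses from servers that omit the field (or predate it) parse to `None` rather than failing validation, so downstream code should guard for it. A small sketch of the expected behaviour under pydantic v1 semantics, where an un-defaulted `Optional` field defaults to `None`; the payload below is fabricated for illustration:

```python
from text_generation.types import StreamResponse

# Minimal payload without top_tokens, as an older router would send it.
raw = {"token": {"id": 1, "text": "Hello", "logprob": -0.1, "special": False}}

resp = StreamResponse(**raw)
assert resp.top_tokens is None  # omitted field parses as None, not an error
```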
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 6f1d4c8f..76e70bb7 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -239,6 +239,7 @@ pub(crate) struct BestOfSequence {
     pub seed: Option<u64>,
     pub prefill: Vec<PrefillToken>,
     pub tokens: Vec<Token>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
 }

@@ -254,6 +255,7 @@ pub(crate) struct Details {
     pub tokens: Vec<Token>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
 }

@@ -278,8 +280,8 @@ pub(crate) struct StreamDetails {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
     pub token: Token,
-    #[schema(nullable = true, default = "null")]
-    pub top_tokens: Option<Vec<Token>>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    pub top_tokens: Vec<Token>,
     #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
     #[schema(nullable = true, default = "null")]
diff --git a/router/src/server.rs b/router/src/server.rs
index d698ed99..16dd87bc 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -198,6 +198,11 @@ async fn generate(
                         .collect()
                 });

+                // let top_tokens = match response.top_tokens.is_empty() {
+                //     true => None,
+                //     false => Some(response.top_tokens),
+                // };
+
                 Some(Details {
                     finish_reason: FinishReason::from(response.generated_text.finish_reason),
                     generated_tokens: response.generated_text.generated_tokens,
@@ -376,12 +381,8 @@ async fn generate_stream(
                     tracing::error!("{err}");
                     yield Ok(Event::from(err));
                 } else {
-<<<<<<< HEAD
-                    match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
-=======
-                    let top_n_tokens = req.0.parameters.top_n_tokens;
-                    match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
->>>>>>> 7c014c7 (Add WIP support for returning top tokens)
+                    let top_n_tokens = req.parameters.top_n_tokens;
+                    match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                         // Keep permit as long as generate_stream lives
                         Ok((_permit, mut response_stream)) => {
                             // Server-Sent Event stream
@@ -401,7 +402,7 @@ async fn generate_stream(
                             // StreamResponse
                             let stream_token = StreamResponse {
                                 token,
-                                top_tokens: top_n_tokens.and(Some(top_tokens)),
+                                top_tokens,
                                 generated_text: None,
                                 details: None,
                             };
@@ -463,7 +464,7 @@ async fn generate_stream(

                             let stream_token = StreamResponse {
                                 token,
-                                top_tokens: top_n_tokens.and(Some(top_tokens)),
+                                top_tokens,
                                 generated_text: Some(output_text),
                                 details
                             };
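Note that with `skip_serializing_if = "Vec::is_empty"` the router omits `top_tokens` from the JSON entirely when the request did not ask for any, which is the server-side counterpart of making the Python fields `Optional`. A hedged end-to-end sketch of the streaming path (placeholder endpoint and prompt):

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

for stream_response in client.generate_stream(
    "What is Deep Learning?", max_new_tokens=8, top_n_tokens=3
):
    # The field is absent (None) unless top_n_tokens was requested.
    if stream_response.top_tokens:
        print([token.text for token in stream_response.top_tokens])
```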