diff --git a/router/src/infer.rs b/router/src/infer.rs
index 3f48a3a7..5955faec 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -185,28 +185,35 @@ impl Infer {
     pub(crate) async fn generate_best_of(
         &self,
         request: GenerateRequest,
         best_of: usize,
-    ) -> Result<InferResponse, InferError> {
+    ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
         // validate best_of parameter separately
         let best_of = self.validation.validate_best_of(best_of)?;
 
         // create multiple generate requests
-        let infer_responses: Vec<InferResponse> =
+        let mut infer_responses: Vec<InferResponse> =
             try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
 
         // get the sequence with the highest log probability per token
+        let mut max_index = 0;
         let mut max_logprob: f32 = f32::MIN;
-        let mut best_response = None;
-        for response in infer_responses {
-            // sum logprobs of the generated tokens
-            let sequence_logprob = response.tokens.iter().map(|token| token.logprob).sum();
+
+        for (i, response) in infer_responses.iter().enumerate() {
+            // mean logprobs of the generated tokens
+            let sequence_logprob = response
+                .tokens
+                .iter()
+                .map(|token| token.logprob)
+                .sum::<f32>()
+                / response.tokens.len() as f32;
 
             // set best sequence
             if sequence_logprob > max_logprob {
+                max_index = i;
                 max_logprob = sequence_logprob;
-                best_response = Some(response);
             }
         }
-        Ok(best_response.expect("best_response is None. This is a bug."))
+        let best_response = infer_responses.remove(max_index);
+        Ok((best_response, infer_responses))
     }
 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 5a73577f..3873a3b5 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -168,6 +168,20 @@ pub(crate) enum FinishReason {
     StopSequence,
 }
 
+#[derive(Serialize, ToSchema)]
+pub(crate) struct BestOfSequence {
+    #[schema(example = "test")]
+    pub generated_text: String,
+    #[schema(example = "length")]
+    pub finish_reason: FinishReason,
+    #[schema(example = 1)]
+    pub generated_tokens: u32,
+    #[schema(example = 42)]
+    pub seed: Option<u64>,
+    pub prefill: Vec<PrefillToken>,
+    pub tokens: Vec<Token>,
+}
+
 #[derive(Serialize, ToSchema)]
 pub(crate) struct Details {
     #[schema(example = "length")]
@@ -178,6 +192,8 @@ pub(crate) struct Details {
     pub seed: Option<u64>,
     pub prefill: Vec<PrefillToken>,
     pub tokens: Vec<Token>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub best_of_sequences: Option<Vec<BestOfSequence>>,
 }
 
 #[derive(Serialize, ToSchema)]
diff --git a/router/src/server.rs b/router/src/server.rs
index e91be18c..3add4c7a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,10 +1,10 @@
 /// HTTP Server logic
-use crate::infer::{InferError, InferStreamResponse};
+use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
-    GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
-    Validation,
+    BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
+    GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
+    StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -87,21 +87,52 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
-    let response = match req.0.parameters.best_of {
-        Some(best_of) if best_of > 1 => infer.generate_best_of(req.0, best_of).await?,
-        _ => infer.generate(req.0).await?,
+    let (response, best_of_responses) = match req.0.parameters.best_of {
+        Some(best_of) if best_of > 1 => {
+            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+            (response, Some(best_of_responses))
+        }
+        _ => (infer.generate(req.0).await?, None),
     };
 
     // Token details
     let details = match details {
-        true => Some(Details {
-            finish_reason: FinishReason::from(response.generated_text.finish_reason),
-            generated_tokens: response.generated_text.generated_tokens,
-            prefill: response.prefill,
-            tokens: response.tokens,
-            seed: response.generated_text.seed,
-        }),
+        true => {
+            // convert best_of_responses
+            let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
+                responses
+                    .into_iter()
+                    .map(|response: InferResponse| {
+                        // Add prompt if return_full_text
+                        let mut output_text = response.generated_text.text;
+                        if let Some(prompt) = &add_prompt {
+                            output_text = prompt.clone() + &output_text;
+                        }
+
+                        BestOfSequence {
+                            generated_text: output_text,
+                            finish_reason: FinishReason::from(
+                                response.generated_text.finish_reason,
+                            ),
+                            generated_tokens: response.generated_text.generated_tokens,
+                            prefill: response.prefill,
+                            tokens: response.tokens,
+                            seed: response.generated_text.seed,
+                        }
+                    })
+                    .collect()
+            });
+
+            Some(Details {
+                finish_reason: FinishReason::from(response.generated_text.finish_reason),
+                generated_tokens: response.generated_text.generated_tokens,
+                prefill: response.prefill,
+                tokens: response.tokens,
+                seed: response.generated_text.seed,
+                best_of_sequences,
+            })
+        }
         false => None,
     };
@@ -222,26 +253,26 @@ async fn generate(
 /// Generate a stream of token using Server-Sent Events
 #[utoipa::path(
-    post,
-    tag = "Text Generation Inference",
-    path = "/generate_stream",
-    request_body = GenerateRequest,
-    responses(
-        (status = 200, description = "Generated Text", body = StreamResponse,
-        content_type = "text/event-stream"),
-        (status = 424, description = "Generation Error", body = ErrorResponse,
-        example = json ! ({"error": "Request failed during generation"}),
-        content_type = "text/event-stream"),
-        (status = 429, description = "Model is overloaded", body = ErrorResponse,
-        example = json ! ({"error": "Model is overloaded"}),
-        content_type = "text/event-stream"),
-        (status = 422, description = "Input validation error", body = ErrorResponse,
-        example = json ! ({"error": "Input validation error"}),
-        content_type = "text/event-stream"),
-        (status = 500, description = "Incomplete generation", body = ErrorResponse,
-        example = json ! ({"error": "Incomplete generation"}),
-        content_type = "text/event-stream"),
-    )
+post,
+tag = "Text Generation Inference",
+path = "/generate_stream",
+request_body = GenerateRequest,
+responses(
+(status = 200, description = "Generated Text", body = StreamResponse,
+content_type = "text/event-stream"),
+(status = 424, description = "Generation Error", body = ErrorResponse,
+example = json ! ({"error": "Request failed during generation"}),
+content_type = "text/event-stream"),
+(status = 429, description = "Model is overloaded", body = ErrorResponse,
+example = json ! ({"error": "Model is overloaded"}),
+content_type = "text/event-stream"),
+(status = 422, description = "Input validation error", body = ErrorResponse,
+example = json ! ({"error": "Input validation error"}),
+content_type = "text/event-stream"),
+(status = 500, description = "Incomplete generation", body = ErrorResponse,
+example = json ! ({"error": "Incomplete generation"}),
+content_type = "text/event-stream"),
+)
 )]
 #[instrument(
     skip(infer),
@@ -403,10 +434,10 @@ async fn generate_stream(
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
-    get,
-    tag = "Text Generation Inference",
-    path = "/metrics",
-    responses((status = 200, description = "Prometheus Metrics", body = String))
+get,
+tag = "Text Generation Inference",
+path = "/metrics",
+responses((status = 200, description = "Prometheus Metrics", body = String))
 )]
 async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }
@@ -432,35 +463,36 @@ pub async fn run(
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
-        paths(
-            generate,
-            generate_stream,
-            metrics,
-        ),
-        components(
-            schemas(
-                GenerateRequest,
-                GenerateParameters,
-                PrefillToken,
-                Token,
-                GenerateResponse,
-                Details,
-                FinishReason,
-                StreamResponse,
-                StreamDetails,
-                ErrorResponse,
-            )
-        ),
-        tags(
-            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
-        ),
-        info(
-            title = "Text Generation Inference",
-            license(
-                name = "Apache 2.0",
-                url = "https://www.apache.org/licenses/LICENSE-2.0"
-            )
-        )
+    paths(
+    generate,
+    generate_stream,
+    metrics,
+    ),
+    components(
+    schemas(
+    GenerateRequest,
+    GenerateParameters,
+    PrefillToken,
+    Token,
+    GenerateResponse,
+    BestOfSequence,
+    Details,
+    FinishReason,
+    StreamResponse,
+    StreamDetails,
+    ErrorResponse,
+    )
+    ),
+    tags(
+    (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
+    ),
+    info(
+    title = "Text Generation Inference",
+    license(
+    name = "Apache 2.0",
+    url = "https://www.apache.org/licenses/LICENSE-2.0"
+    )
+    )
 )]
 struct ApiDoc;