diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 1865cf90..80466fe6 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -31,6 +31,8 @@ struct Args {
     quantize: bool,
     #[clap(default_value = "128", long, env)]
     max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
     #[clap(default_value = "4", long, env)]
     max_stop_sequences: usize,
     #[clap(default_value = "1000", long, env)]
@@ -86,6 +88,7 @@ fn main() -> ExitCode {
         num_shard,
         quantize,
         max_concurrent_requests,
+        max_best_of,
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
@@ -363,6 +366,8 @@ fn main() -> ExitCode {
         "text-generation-router".to_string(),
         "--max-concurrent-requests".to_string(),
         max_concurrent_requests.to_string(),
+        "--max-best-of".to_string(),
+        max_best_of.to_string(),
         "--max-stop-sequences".to_string(),
         max_stop_sequences.to_string(),
         "--max-input-length".to_string(),
diff --git a/router/src/infer.rs b/router/src/infer.rs
index d0964f97..15bf682e 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -2,6 +2,7 @@
 use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
+use futures::future::try_join_all;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_client::{
@@ -177,6 +178,36 @@ impl Infer {
             Err(err)
         }
     }
+    /// Add best_of new requests to the queue and return the InferResponse of the sequence with
+    /// the highest log probability per token
+    #[instrument(skip(self))]
+    pub(crate) async fn generate_best_of(
+        &self,
+        request: GenerateRequest,
+        best_of: usize,
+    ) -> Result<InferResponse, InferError> {
+        // validate best_of parameter separately
+        let best_of = self.validation.validate_best_of(best_of)?;
+
+        // create multiple generate requests
+        let infer_responses: Vec<InferResponse> =
+            try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
+
+        // get the sequence with the highest log probability per token
+        let mut max_logprob: f32 = f32::MIN;
+        let mut best_response = None;
+        for response in infer_responses {
+            // mean logprob of the generated tokens
+            let sequence_logprob = response.tokens.iter().map(|token| token.logprob).sum::<f32>() / response.tokens.len() as f32;
+
+            // set best sequence
+            if sequence_logprob > max_logprob {
+                max_logprob = sequence_logprob;
+                best_response = Some(response);
+            }
+        }
+        Ok(best_response.expect("best_response is None. This is a bug."))
+    }
 }
 
 /// Batching logic
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 9fcc5085..5a73577f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -12,6 +12,9 @@ use validation::Validation;
 
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateParameters {
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
+    pub best_of: Option<usize>,
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0.0,
@@ -71,6 +74,12 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "true")]
     pub details: bool,
     #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0,
+        nullable = true,
+        default = "null",
+        example = "null"
+    )]
     pub seed: Option<u64>,
 }
 
@@ -80,6 +89,7 @@ fn default_max_new_tokens() -> u32 {
 
 fn default_parameters() -> GenerateParameters {
     GenerateParameters {
+        best_of: None,
         temperature: None,
         repetition_penalty: None,
         top_k: None,
diff --git a/router/src/main.rs b/router/src/main.rs
index a51d3168..2ccf66b3 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -23,6 +23,8 @@ use tracing_subscriber::{EnvFilter, Layer};
 struct Args {
     #[clap(default_value = "128", long, env)]
     max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
     #[clap(default_value = "4", long, env)]
     max_stop_sequences: usize,
     #[clap(default_value = "1000", long, env)]
@@ -55,6 +57,7 @@ fn main() -> Result<(), std::io::Error> {
     // Pattern match configuration
     let Args {
         max_concurrent_requests,
+        max_best_of,
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
@@ -145,6 +148,7 @@ fn main() -> Result<(), std::io::Error> {
             server::run(
                 compat_return_full_text,
                 max_concurrent_requests,
+                max_best_of,
                 max_stop_sequences,
                 max_input_length,
                 max_total_tokens,
diff --git a/router/src/server.rs b/router/src/server.rs
index ef10b7b1..e91be18c 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,5 +1,6 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
+use crate::validation::ValidationError;
 use crate::{
     CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
     GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
@@ -64,6 +65,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
         .generate(GenerateRequest {
             inputs: "liveness".to_string(),
             parameters: GenerateParameters {
+                best_of: None,
                 temperature: None,
                 repetition_penalty: None,
                 top_k: None,
@@ ... @@ async fn generate(
     // Inference
-    let response = infer.generate(req.0).await?;
+    let response = match req.0.parameters.best_of {
+        Some(best_of) if best_of > 1 => infer.generate_best_of(req.0, best_of).await?,
+        _ => infer.generate(req.0).await?,
+    };
 
     // Token details
     let details = match details {
@@ -279,107 +284,115 @@ async fn generate_stream(
         }
 
         let details = req.0.parameters.details;
 
-        match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
-            Ok(mut response_stream) => {
-                // Server-Sent Event stream
-                while let Some(response) = response_stream.next().await {
-                    match response {
-                        Ok(response) => {
-                            match response {
-                                // Prefill is ignored
-                                InferStreamResponse::Prefill(_) => {}
-                                // Yield event for every new token
-                                InferStreamResponse::Token(token) => {
-                                    // StreamResponse
-                                    let stream_token = StreamResponse {
-                                        token,
-                                        generated_text: None,
-                                        details: None,
-                                    };
-
-                                    yield Ok(Event::default().json_data(stream_token).unwrap())
-                                }
-                                // Yield event for last token and compute timings
-                                InferStreamResponse::End {
-                                    token,
-                                    generated_text,
-                                    start,
-                                    queued,
-                                } => {
-                                    // Token details
-                                    let details = match details {
-                                        true => Some(StreamDetails {
-                                            finish_reason: FinishReason::from(generated_text.finish_reason),
-                                            generated_tokens: generated_text.generated_tokens,
-                                            seed: generated_text.seed,
-                                        }),
-                                        false => None,
-                                    };
-
-                                    // Timings
-                                    let total_time = start_time.elapsed();
-                                    let validation_time = queued - start_time;
-                                    let queue_time = start - queued;
-                                    let inference_time = Instant::now() - start;
-                                    let time_per_token = inference_time / generated_text.generated_tokens;
-
-                                    // Tracing metadata
-                                    span.record("total_time", format!("{total_time:?}"));
-                                    span.record("validation_time", format!("{validation_time:?}"));
-                                    span.record("queue_time", format!("{queue_time:?}"));
-                                    span.record("inference_time", format!("{inference_time:?}"));
-                                    span.record("time_per_token", format!("{time_per_token:?}"));
-                                    span.record("seed", format!("{:?}", generated_text.seed));
-                                    tracing::info!(parent: &span, "Output: {}", generated_text.text);
-
-                                    // Metrics
-                                    metrics::increment_counter!("tgi_request_success");
-                                    metrics::histogram!("tgi_request_duration", total_time);
-                                    metrics::histogram!("tgi_request_validation_duration", validation_time);
-                                    metrics::histogram!("tgi_request_queue_duration", queue_time);
-                                    metrics::histogram!("tgi_request_inference_duration", inference_time);
-                                    metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
-                                    metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
-
-                                    // StreamResponse
-                                    end_reached = true;
-
-                                    let mut output_text = generated_text.text;
-                                    if let Some(prompt) = add_prompt {
-                                        output_text = prompt + &output_text;
-                                    }
-
-                                    let stream_token = StreamResponse {
-                                        token,
-                                        generated_text: Some(output_text),
-                                        details
-                                    };
-
-                                    yield Ok(Event::default().json_data(stream_token).unwrap());
-                                    break;
-                                }
-                            }
-                        }
-                        // yield error
-                        Err(err) => {
-                            error = true;
-                            yield Ok(Event::from(err));
-                            break;
-                        }
-                    }
-                }
-            },
-            // yield error
-            Err(err) => {
-                error = true;
-                yield Ok(Event::from(err));
-            }
-        }
-        // Check if generation reached the end
-        // Skip if we already sent an error
-        if !end_reached && !error {
-            let err = InferError::IncompleteGeneration;
-            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
+        let best_of = req.0.parameters.best_of.unwrap_or(1);
+        if best_of == 1 {
+            match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
+                Ok(mut response_stream) => {
+                    // Server-Sent Event stream
+                    while let Some(response) = response_stream.next().await {
+                        match response {
+                            Ok(response) => {
+                                match response {
+                                    // Prefill is ignored
+                                    InferStreamResponse::Prefill(_) => {}
+                                    // Yield event for every new token
+                                    InferStreamResponse::Token(token) => {
+                                        // StreamResponse
+                                        let stream_token = StreamResponse {
+                                            token,
+                                            generated_text: None,
+                                            details: None,
+                                        };
+
+                                        yield Ok(Event::default().json_data(stream_token).unwrap())
+                                    }
+                                    // Yield event for last token and compute timings
+                                    InferStreamResponse::End {
+                                        token,
+                                        generated_text,
+                                        start,
+                                        queued,
+                                    } => {
+                                        // Token details
+                                        let details = match details {
+                                            true => Some(StreamDetails {
+                                                finish_reason: FinishReason::from(generated_text.finish_reason),
+                                                generated_tokens: generated_text.generated_tokens,
+                                                seed: generated_text.seed,
+                                            }),
+                                            false => None,
+                                        };
+
+                                        // Timings
+                                        let total_time = start_time.elapsed();
+                                        let validation_time = queued - start_time;
+                                        let queue_time = start - queued;
+                                        let inference_time = Instant::now() - start;
+                                        let time_per_token = inference_time / generated_text.generated_tokens;
+
+                                        // Tracing metadata
+                                        span.record("total_time", format!("{total_time:?}"));
+                                        span.record("validation_time", format!("{validation_time:?}"));
+                                        span.record("queue_time", format!("{queue_time:?}"));
+                                        span.record("inference_time", format!("{inference_time:?}"));
+                                        span.record("time_per_token", format!("{time_per_token:?}"));
+                                        span.record("seed", format!("{:?}", generated_text.seed));
+                                        tracing::info!(parent: &span, "Output: {}", generated_text.text);
+
+                                        // Metrics
+                                        metrics::increment_counter!("tgi_request_success");
+                                        metrics::histogram!("tgi_request_duration", total_time);
+                                        metrics::histogram!("tgi_request_validation_duration", validation_time);
+                                        metrics::histogram!("tgi_request_queue_duration", queue_time);
+                                        metrics::histogram!("tgi_request_inference_duration", inference_time);
+                                        metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
+                                        metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
+
+                                        // StreamResponse
+                                        end_reached = true;
+
+                                        let mut output_text = generated_text.text;
+                                        if let Some(prompt) = add_prompt {
+                                            output_text = prompt + &output_text;
+                                        }
+
+                                        let stream_token = StreamResponse {
+                                            token,
+                                            generated_text: Some(output_text),
+                                            details
+                                        };
+
+                                        yield Ok(Event::default().json_data(stream_token).unwrap());
+                                        break;
+                                    }
+                                }
+                            }
+                            // yield error
+                            Err(err) => {
+                                error = true;
+                                yield Ok(Event::from(err));
+                                break;
+                            }
+                        }
+                    }
+                },
+                // yield error
+                Err(err) => {
+                    error = true;
+                    yield Ok(Event::from(err));
+                }
+            }
+            // Check if generation reached the end
+            // Skip if we already sent an error
+            if !end_reached && !error {
+                let err = InferError::IncompleteGeneration;
+                metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
+                tracing::error!("{err}");
+                yield Ok(Event::from(err));
+            }
+        } else {
+            let err = InferError::from(ValidationError::StreamBestOf);
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
         }
@@ -404,6 +417,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 pub async fn run(
     compat_return_full_text: bool,
     max_concurrent_requests: usize,
+    max_best_of: usize,
     max_stop_sequences: usize,
     max_input_length: usize,
     max_total_tokens: usize,
@@ -454,6 +468,7 @@ pub async fn run(
     let validation = Validation::new(
         validation_workers,
         tokenizer,
+        max_best_of,
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 42af0169..95a8eb9e 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -1,4 +1,4 @@
-use crate::validation::ValidationError::EmptyInput;
+use crate::validation::ValidationError::{EmptyInput, SeedBestOf};
 /// Payload validation logic
 use crate::{GenerateParameters, GenerateRequest};
 use rand::rngs::ThreadRng;
@@ -13,6 +13,9 @@ use tracing::{instrument, Span};
 /// Validation
 #[derive(Debug, Clone)]
 pub struct Validation {
+    /// maximum value for the best_of parameter
+    #[allow(dead_code)]
+    max_best_of: usize,
     /// Channel to communicate with the background validation task
     sender: mpsc::UnboundedSender<ValidationRequest>,
 }
@@ -21,6 +24,7 @@ impl Validation {
     pub(crate) fn new(
         workers: usize,
         tokenizer: Tokenizer,
+        max_best_of: usize,
         max_stop_sequences: usize,
         max_input_length: usize,
         max_total_tokens: usize,
@@ -39,6 +43,7 @@ impl Validation {
         ));
 
         Self {
+            max_best_of,
             sender: validation_sender,
         }
     }
@@ -60,6 +65,20 @@ impl Validation {
         // Unwrap is safe here
         receiver.await.unwrap()
     }
+
+    /// Validate the best_of parameter
+    #[instrument(skip_all)]
+    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
+        if self.max_best_of == 1 && best_of != 1 {
+            return Err(ValidationError::BestOfDisabled);
+        }
+
+        if best_of > self.max_best_of {
+            return Err(ValidationError::BestOf(self.max_best_of, best_of));
+        }
+
+        Ok(best_of)
+    }
 }
 
 /// Validation task
@@ -150,6 +169,7 @@ fn validate(
     rng: &mut ThreadRng,
 ) -> Result<ValidGenerateRequest, ValidationError> {
     let GenerateParameters {
+        best_of,
         temperature,
         repetition_penalty,
         top_k,
@@ -217,7 +237,12 @@ fn validate(
     // If seed is None, assign a random one
     let seed = match seed {
         None => rng.gen(),
-        Some(seed) => seed,
+        Some(seed) => {
+            if best_of.unwrap_or(1) > 1 {
+                return Err(SeedBestOf);
+            }
+            seed
+        }
     };
 
     // Check if inputs is empty
@@ -307,6 +332,14 @@ pub(crate) struct ValidGenerateRequest {
 
 #[derive(Error, Debug)]
 pub enum ValidationError {
+    #[error("`best_of` != 1 is not allowed for this endpoint")]
+    BestOfDisabled,
+    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
+    BestOf(usize, usize),
+    #[error("`best_of` != 1 is not supported when streaming tokens")]
+    StreamBestOf,
+    #[error("`seed` must not be set when `best_of` > 1")]
+    SeedBestOf,
     #[error("`temperature` must be strictly positive")]
     Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
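
Usage note (not part of the patch above): with a router built from this change, `best_of` is just another field under `parameters` in the generate request. The sketch below is a minimal, hypothetical client; the host, port and route (`http://localhost:3000/generate`) and the reqwest ("json" feature), tokio and serde_json dependencies are assumptions for illustration, not something this diff introduces. What does come from the patch: `best_of` must be at most `--max-best-of`, `seed` must stay unset when `best_of > 1`, and the streaming endpoint rejects `best_of > 1`. Because `generate_best_of` clones the same request `best_of` times, sampling (`do_sample`) is what makes the candidate sequences differ.

// Hypothetical client sketch (assumptions: router on localhost:3000, default /generate route).
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // best_of = 2 races two sampled generations and returns the one with the
    // highest per-token log probability; `seed` is deliberately left unset.
    let body = json!({
        "inputs": "What is Deep Learning?",
        "parameters": { "best_of": 2, "do_sample": true, "max_new_tokens": 20 }
    });

    let answer = reqwest::Client::new()
        .post("http://localhost:3000/generate")
        .json(&body)
        .send()
        .await?
        .text()
        .await?;
    println!("{answer}");
    Ok(())
}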