feat(router): add best_of parameter

This commit is contained in:
OlivierDehaene 2023-03-09 13:06:12 +01:00
parent e8bfe199ba
commit 5932ff4aa2
6 changed files with 191 additions and 93 deletions

View File

@ -31,6 +31,8 @@ struct Args {
quantize: bool,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "1000", long, env)]
@ -86,6 +88,7 @@ fn main() -> ExitCode {
num_shard,
quantize,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
@ -363,6 +366,8 @@ fn main() -> ExitCode {
"text-generation-router".to_string(),
"--max-concurrent-requests".to_string(),
max_concurrent_requests.to_string(),
"--max-best-of".to_string(),
max_best_of.to_string(),
"--max-stop-sequences".to_string(),
max_stop_sequences.to_string(),
"--max-input-length".to_string(),

View File

@ -2,6 +2,7 @@
use crate::validation::{Validation, ValidationError};
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_client::{
@ -177,6 +178,36 @@ impl Infer {
Err(err)
}
}
/// Fan the request out into `best_of` identical generation requests and return an
/// InferResponse for the sequence judged best by its token log probabilities.
///
/// NOTE(review): the original description says "highest log probability per token",
/// but the score computed below is the *sum* of token logprobs (no division by
/// sequence length), which biases toward shorter sequences — confirm intent.
#[instrument(skip(self))]
pub(crate) async fn generate_best_of(
    &self,
    request: GenerateRequest,
    best_of: usize,
) -> Result<InferResponse, InferError> {
    // Reject values above the server-configured `max_best_of` before spawning work
    let best_of = self.validation.validate_best_of(best_of)?;
    // Run `best_of` generations concurrently; the first failure aborts the whole batch
    let infer_responses: Vec<InferResponse> =
        try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
    // Select the response with the highest summed logprob over its generated tokens
    let mut max_logprob: f32 = f32::MIN;
    let mut best_response = None;
    for response in infer_responses {
        // Total log probability of the generated tokens (f32, inferred from the comparison)
        let sequence_logprob = response.tokens.iter().map(|token| token.logprob).sum();
        // Strictly-greater comparison: on exact ties the earliest response wins.
        // NOTE(review): a NaN logprob never compares greater, so an all-NaN batch
        // would leave best_response as None and hit the expect below — verify upstream
        // guarantees finite logprobs.
        if sequence_logprob > max_logprob {
            max_logprob = sequence_logprob;
            best_response = Some(response);
        }
    }
    // Relies on best_of >= 1; validate_best_of as written does not reject 0,
    // so a zero value would panic here — TODO confirm callers never pass 0.
    Ok(best_response.expect("best_response is None. This is a bug."))
}
}
/// Batching logic

View File

@ -12,6 +12,9 @@ use validation::Validation;
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
pub best_of: Option<usize>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
@ -71,6 +74,12 @@ pub(crate) struct GenerateParameters {
#[schema(default = "true")]
pub details: bool,
#[serde(default)]
#[schema(
exclusive_minimum = 0,
nullable = true,
default = "null",
example = "null"
)]
pub seed: Option<u64>,
}
@ -80,6 +89,7 @@ fn default_max_new_tokens() -> u32 {
fn default_parameters() -> GenerateParameters {
GenerateParameters {
best_of: None,
temperature: None,
repetition_penalty: None,
top_k: None,

View File

@ -23,6 +23,8 @@ use tracing_subscriber::{EnvFilter, Layer};
struct Args {
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "1000", long, env)]
@ -55,6 +57,7 @@ fn main() -> Result<(), std::io::Error> {
// Pattern match configuration
let Args {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
@ -145,6 +148,7 @@ fn main() -> Result<(), std::io::Error> {
server::run(
compat_return_full_text,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,

View File

@ -1,5 +1,6 @@
/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
@ -64,6 +65,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
.generate(GenerateRequest {
inputs: "liveness".to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: None,
repetition_penalty: None,
top_k: None,
@ -128,7 +130,10 @@ async fn generate(
let details = req.0.parameters.details;
// Inference
let response = infer.generate(req.0).await?;
let response = match req.0.parameters.best_of {
Some(best_of) if best_of > 1 => infer.generate_best_of(req.0, best_of).await?,
_ => infer.generate(req.0).await?,
};
// Token details
let details = match details {
@ -279,6 +284,8 @@ async fn generate_stream(
}
let details = req.0.parameters.details;
let best_of = req.0.parameters.best_of.unwrap_or(1);
if best_of == 1 {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
Ok(mut response_stream) => {
// Server-Sent Event stream
@ -383,6 +390,12 @@ async fn generate_stream(
tracing::error!("{err}");
yield Ok(Event::from(err));
}
} else {
let err = InferError::from(ValidationError::StreamBestOf);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
}
};
(headers, Sse::new(stream).keep_alive(KeepAlive::default()))
@ -404,6 +417,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
pub async fn run(
compat_return_full_text: bool,
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
@ -454,6 +468,7 @@ pub async fn run(
let validation = Validation::new(
validation_workers,
tokenizer,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,

View File

@ -1,4 +1,4 @@
use crate::validation::ValidationError::EmptyInput;
use crate::validation::ValidationError::{EmptyInput, SeedBestOf};
/// Payload validation logic
use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng;
@ -13,6 +13,9 @@ use tracing::{instrument, Span};
/// Validation
#[derive(Debug, Clone)]
pub struct Validation {
/// maximum value for the best_of parameter
#[allow(dead_code)]
max_best_of: usize,
/// Channel to communicate with the background validation task
sender: mpsc::UnboundedSender<ValidationRequest>,
}
@ -21,6 +24,7 @@ impl Validation {
pub(crate) fn new(
workers: usize,
tokenizer: Tokenizer,
max_best_of: usize,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
@ -39,6 +43,7 @@ impl Validation {
));
Self {
max_best_of,
sender: validation_sender,
}
}
@ -60,6 +65,20 @@ impl Validation {
// Unwrap is safe here
receiver.await.unwrap()
}
/// Validate the `best_of` parameter against the configured `max_best_of` limit
/// and return it unchanged on success.
///
/// # Errors
/// - [`ValidationError::BestOf`] when `best_of` is 0 or exceeds `max_best_of`
///   (enforcing the `0 < best_of <= max` contract stated in the error message).
/// - [`ValidationError::BestOfDisabled`] when the server was started with
///   `max_best_of == 1` but the client requested a different value.
#[instrument(skip_all)]
pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
    // Bug fix: `best_of == 0` previously passed validation (0 > max is false),
    // letting callers spawn zero generations and then panic on an empty result
    // set; the error text already promises "must be > 0", so enforce it here.
    if best_of == 0 || best_of > self.max_best_of {
        return Err(ValidationError::BestOf(self.max_best_of, best_of));
    }
    if self.max_best_of == 1 && best_of != 1 {
        return Err(ValidationError::BestOfDisabled);
    }
    Ok(best_of)
}
}
/// Validation task
@ -150,6 +169,7 @@ fn validate(
rng: &mut ThreadRng,
) -> Result<ValidGenerateRequest, ValidationError> {
let GenerateParameters {
best_of,
temperature,
repetition_penalty,
top_k,
@ -217,7 +237,12 @@ fn validate(
// If seed is None, assign a random one
let seed = match seed {
None => rng.gen(),
Some(seed) => seed,
Some(seed) => {
if best_of.unwrap_or(1) > 1 {
return Err(SeedBestOf);
}
seed
}
};
// Check if inputs is empty
@ -307,6 +332,14 @@ pub(crate) struct ValidGenerateRequest {
#[derive(Error, Debug)]
pub enum ValidationError {
#[error("`best_of` != 1 is not allowed for this endpoint")]
BestOfDisabled,
#[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
BestOf(usize, usize),
#[error("`best_of` != 1 is not supported when streaming tokens")]
StreamBestOf,
#[error("`seed` must not be set when `best_of` > 1")]
SeedBestOf,
#[error("`temperature` must be strictly positive")]
Temperature,
#[error("`repetition_penalty` must be strictly positive")]