Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)
feat(router): add best_of parameter
commit 5932ff4aa2 (parent e8bfe199ba)
@@ -31,6 +31,8 @@ struct Args {
     quantize: bool,
     #[clap(default_value = "128", long, env)]
     max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
     #[clap(default_value = "4", long, env)]
     max_stop_sequences: usize,
     #[clap(default_value = "1000", long, env)]
@@ -86,6 +88,7 @@ fn main() -> ExitCode {
         num_shard,
         quantize,
         max_concurrent_requests,
+        max_best_of,
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
@@ -363,6 +366,8 @@ fn main() -> ExitCode {
         "text-generation-router".to_string(),
         "--max-concurrent-requests".to_string(),
         max_concurrent_requests.to_string(),
+        "--max-best-of".to_string(),
+        max_best_of.to_string(),
         "--max-stop-sequences".to_string(),
         max_stop_sequences.to_string(),
         "--max-input-length".to_string(),
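For context, a minimal, self-contained sketch of how a clap derive struct like the one in these hunks (apparently the launcher's Args) parses the new flag. The struct below is illustrative only and assumes a clap version with the derive and env features, as used by the repository:

use clap::Parser;

/// Illustrative subset of the launcher arguments shown above.
#[derive(Parser, Debug)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    // New flag added by this commit: accepted as `--max-best-of <N>`
    // or via the MAX_BEST_OF environment variable, defaulting to 2.
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
}

fn main() {
    let args = Args::parse();
    println!("max_best_of = {}", args.max_best_of);
}

The launcher then forwards the value to the router process as `--max-best-of`, as the third hunk shows.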
@@ -2,6 +2,7 @@
 use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
+use futures::future::try_join_all;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_client::{
@@ -177,6 +178,36 @@ impl Infer {
             Err(err)
         }
     }
+    /// Add a best_of new request to the queue and return a InferResponse of the sequence with
+    /// the highest log probability per token
+    #[instrument(skip(self))]
+    pub(crate) async fn generate_best_of(
+        &self,
+        request: GenerateRequest,
+        best_of: usize,
+    ) -> Result<InferResponse, InferError> {
+        // validate best_of parameter separately
+        let best_of = self.validation.validate_best_of(best_of)?;
+
+        // create multiple generate requests
+        let infer_responses: Vec<InferResponse> =
+            try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
+
+        // get the sequence with the highest log probability per token
+        let mut max_logprob: f32 = f32::MIN;
+        let mut best_response = None;
+        for response in infer_responses {
+            // sum logprobs of the generated tokens
+            let sequence_logprob = response.tokens.iter().map(|token| token.logprob).sum();
+
+            // set best sequence
+            if sequence_logprob > max_logprob {
+                max_logprob = sequence_logprob;
+                best_response = Some(response);
+            }
+        }
+        Ok(best_response.expect("best_response is None. This is a bug."))
+    }
 }

 /// Batching logic
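A standalone sketch of the selection rule used by `generate_best_of` above: run `best_of` candidate generations concurrently, then keep the candidate whose generated-token log-probabilities sum to the largest value. The `Candidate` type and the numbers below are made up for illustration:

#[derive(Debug)]
struct Candidate {
    text: &'static str,
    // log-probability of each generated token
    token_logprobs: Vec<f32>,
}

/// Pick the candidate with the highest summed log-probability,
/// mirroring the loop in `generate_best_of`.
fn pick_best(candidates: Vec<Candidate>) -> Option<Candidate> {
    let mut max_logprob = f32::MIN;
    let mut best = None;
    for candidate in candidates {
        let sequence_logprob: f32 = candidate.token_logprobs.iter().sum();
        if sequence_logprob > max_logprob {
            max_logprob = sequence_logprob;
            best = Some(candidate);
        }
    }
    best
}

fn main() {
    let candidates = vec![
        Candidate { text: "answer A", token_logprobs: vec![-0.2, -1.3, -0.4] },
        Candidate { text: "answer B", token_logprobs: vec![-0.1, -0.3, -0.2] },
    ];
    // "answer B" wins: -0.6 > -1.9
    println!("{:?}", pick_best(candidates));
}

Note that, as written, the hunk sums log-probabilities rather than averaging them per token, so longer candidates are implicitly penalized.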
@@ -12,6 +12,9 @@ use validation::Validation;

 #[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateParameters {
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
+    pub best_of: Option<usize>,
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0.0,
@@ -71,6 +74,12 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "true")]
     pub details: bool,
     #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0,
+        nullable = true,
+        default = "null",
+        example = "null"
+    )]
     pub seed: Option<u64>,
 }

@@ -80,6 +89,7 @@ fn default_max_new_tokens() -> u32 {

 fn default_parameters() -> GenerateParameters {
     GenerateParameters {
+        best_of: None,
         temperature: None,
         repetition_penalty: None,
         top_k: None,
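A hedged sketch (with the struct trimmed to the two fields relevant here, so not the real GenerateParameters) showing how `#[serde(default)]` keeps the new `best_of` field optional on the wire: omitting it deserializes to `None`, matching `default_parameters`. Assumes serde and serde_json as dependencies:

use serde::Deserialize;

// Illustrative subset of GenerateParameters; only best_of and seed are shown.
#[derive(Debug, Deserialize)]
struct Parameters {
    #[serde(default)]
    best_of: Option<usize>,
    #[serde(default)]
    seed: Option<u64>,
}

fn main() -> Result<(), serde_json::Error> {
    // best_of omitted -> None, so existing clients are unaffected
    let p: Parameters = serde_json::from_str(r#"{}"#)?;
    assert_eq!(p.best_of, None);

    // best_of supplied -> Some(2)
    let p: Parameters = serde_json::from_str(r#"{"best_of": 2}"#)?;
    assert_eq!(p.best_of, Some(2));
    println!("{p:?}");
    Ok(())
}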
@@ -23,6 +23,8 @@ use tracing_subscriber::{EnvFilter, Layer};
 struct Args {
     #[clap(default_value = "128", long, env)]
     max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
     #[clap(default_value = "4", long, env)]
     max_stop_sequences: usize,
     #[clap(default_value = "1000", long, env)]
@@ -55,6 +57,7 @@ fn main() -> Result<(), std::io::Error> {
     // Pattern match configuration
     let Args {
         max_concurrent_requests,
+        max_best_of,
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
@@ -145,6 +148,7 @@ fn main() -> Result<(), std::io::Error> {
     server::run(
         compat_return_full_text,
         max_concurrent_requests,
+        max_best_of,
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
@@ -1,5 +1,6 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
+use crate::validation::ValidationError;
 use crate::{
     CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
     GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
@@ -64,6 +65,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         .generate(GenerateRequest {
             inputs: "liveness".to_string(),
             parameters: GenerateParameters {
+                best_of: None,
                 temperature: None,
                 repetition_penalty: None,
                 top_k: None,
@@ -128,7 +130,10 @@ async fn generate(
     let details = req.0.parameters.details;

     // Inference
-    let response = infer.generate(req.0).await?;
+    let response = match req.0.parameters.best_of {
+        Some(best_of) if best_of > 1 => infer.generate_best_of(req.0, best_of).await?,
+        _ => infer.generate(req.0).await?,
+    };

     // Token details
     let details = match details {
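An illustrative client-side call exercising the new dispatch above. The host/port, the `/generate` route path, the prompt, and the use of reqwest (with its json feature) and tokio are assumptions for the sketch, not part of this commit:

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // best_of > 1 routes the request through Infer::generate_best_of;
    // leaving it out (or setting 1) keeps the original single-generation path.
    let body = json!({
        "inputs": "What is Deep Learning?",
        "parameters": {
            "best_of": 2,
            "max_new_tokens": 20
            // note: `seed` must not be set when best_of > 1 (see the validation changes below)
        }
    });

    let response = reqwest::Client::new()
        .post("http://localhost:3000/generate") // assumed local router address
        .json(&body)
        .send()
        .await?;

    println!("{}", response.text().await?);
    Ok(())
}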
This hunk wraps the pre-existing SSE streaming body of `generate_stream` in a new `if best_of == 1 { ... } else { ... }` guard; the inner streaming logic is unchanged apart from re-indentation, and the `else` branch rejects `best_of` for streaming requests with a `StreamBestOf` validation error.

@@ -279,107 +284,115 @@ async fn generate_stream(
     }
     let details = req.0.parameters.details;

+    let best_of = req.0.parameters.best_of.unwrap_or(1);
+    if best_of == 1 {
         match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
             Ok(mut response_stream) => {
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
                     match response {
                         Ok(response) => {
                             match response {
                                 // Prefill is ignored
                                 InferStreamResponse::Prefill(_) => {}
                                 // Yield event for every new token
                                 InferStreamResponse::Token(token) => {
                                     // StreamResponse
                                     let stream_token = StreamResponse {
                                         token,
                                         generated_text: None,
                                         details: None,
                                     };

                                     yield Ok(Event::default().json_data(stream_token).unwrap())
                                 }
                                 // Yield event for last token and compute timings
                                 InferStreamResponse::End {
                                     token,
                                     generated_text,
                                     start,
                                     queued,
                                 } => {
                                     // Token details
                                     let details = match details {
                                         true => Some(StreamDetails {
                                             finish_reason: FinishReason::from(generated_text.finish_reason),
                                             generated_tokens: generated_text.generated_tokens,
                                             seed: generated_text.seed,
                                         }),
                                         false => None,
                                     };

                                     // Timings
                                     let total_time = start_time.elapsed();
                                     let validation_time = queued - start_time;
                                     let queue_time = start - queued;
                                     let inference_time = Instant::now() - start;
                                     let time_per_token = inference_time / generated_text.generated_tokens;

                                     // Tracing metadata
                                     span.record("total_time", format!("{total_time:?}"));
                                     span.record("validation_time", format!("{validation_time:?}"));
                                     span.record("queue_time", format!("{queue_time:?}"));
                                     span.record("inference_time", format!("{inference_time:?}"));
                                     span.record("time_per_token", format!("{time_per_token:?}"));
                                     span.record("seed", format!("{:?}", generated_text.seed));
                                     tracing::info!(parent: &span, "Output: {}", generated_text.text);

                                     // Metrics
                                     metrics::increment_counter!("tgi_request_success");
                                     metrics::histogram!("tgi_request_duration", total_time);
                                     metrics::histogram!("tgi_request_validation_duration", validation_time);
                                     metrics::histogram!("tgi_request_queue_duration", queue_time);
                                     metrics::histogram!("tgi_request_inference_duration", inference_time);
                                     metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
                                     metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);

                                     // StreamResponse
                                     end_reached = true;

                                     let mut output_text = generated_text.text;
                                     if let Some(prompt) = add_prompt {
                                         output_text = prompt + &output_text;
                                     }

                                     let stream_token = StreamResponse {
                                         token,
                                         generated_text: Some(output_text),
                                         details
                                     };

                                     yield Ok(Event::default().json_data(stream_token).unwrap());
                                     break;
                                 }
                             }
                         }
                         // yield error
                         Err(err) => {
                             error = true;
                             yield Ok(Event::from(err));
                             break;
                         }
                     }
                 }
             },
             // yield error
             Err(err) => {
                 error = true;
                 yield Ok(Event::from(err));
             }
         }
         // Check if generation reached the end
         // Skip if we already sent an error
         if !end_reached && !error {
             let err = InferError::IncompleteGeneration;
             metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
         }
+    } else {
+        let err = InferError::from(ValidationError::StreamBestOf);
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        tracing::error!("{err}");
+        yield Ok(Event::from(err));
+    }
@@ -404,6 +417,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 pub async fn run(
     compat_return_full_text: bool,
     max_concurrent_requests: usize,
+    max_best_of: usize,
     max_stop_sequences: usize,
     max_input_length: usize,
     max_total_tokens: usize,
@@ -454,6 +468,7 @@ pub async fn run(
     let validation = Validation::new(
         validation_workers,
         tokenizer,
+        max_best_of,
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
@@ -1,4 +1,4 @@
-use crate::validation::ValidationError::EmptyInput;
+use crate::validation::ValidationError::{EmptyInput, SeedBestOf};
 /// Payload validation logic
 use crate::{GenerateParameters, GenerateRequest};
 use rand::rngs::ThreadRng;
@@ -13,6 +13,9 @@ use tracing::{instrument, Span};
 /// Validation
 #[derive(Debug, Clone)]
 pub struct Validation {
+    /// maximum value for the best_of parameter
+    #[allow(dead_code)]
+    max_best_of: usize,
     /// Channel to communicate with the background validation task
     sender: mpsc::UnboundedSender<ValidationRequest>,
 }
@@ -21,6 +24,7 @@ impl Validation {
     pub(crate) fn new(
         workers: usize,
         tokenizer: Tokenizer,
+        max_best_of: usize,
         max_stop_sequences: usize,
         max_input_length: usize,
         max_total_tokens: usize,
@@ -39,6 +43,7 @@ impl Validation {
         ));

         Self {
+            max_best_of,
             sender: validation_sender,
         }
     }
@@ -60,6 +65,20 @@ impl Validation {
         // Unwrap is safe here
         receiver.await.unwrap()
     }
+
+    /// Validate the best_of parameter
+    #[instrument(skip_all)]
+    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
+        if self.max_best_of == 1 && best_of != 1 {
+            return Err(ValidationError::BestOfDisabled);
+        }
+
+        if best_of > self.max_best_of {
+            return Err(ValidationError::BestOf(self.max_best_of, best_of));
+        }
+
+        Ok(best_of)
+    }
 }

 /// Validation task
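A free-standing restatement of the two checks in `validate_best_of`, with a few spot checks. The helper and error type below are illustrative only, not the router's API:

#[derive(Debug, PartialEq)]
enum BestOfError {
    Disabled,
    TooLarge { max: usize, given: usize },
}

// Same rules as Validation::validate_best_of in the hunk above:
// best_of must be 1 when the feature is disabled (max_best_of == 1),
// and may never exceed max_best_of.
fn check_best_of(max_best_of: usize, best_of: usize) -> Result<usize, BestOfError> {
    if max_best_of == 1 && best_of != 1 {
        return Err(BestOfError::Disabled);
    }
    if best_of > max_best_of {
        return Err(BestOfError::TooLarge { max: max_best_of, given: best_of });
    }
    Ok(best_of)
}

fn main() {
    assert_eq!(check_best_of(2, 1), Ok(1));
    assert_eq!(check_best_of(2, 2), Ok(2));
    assert_eq!(check_best_of(2, 3), Err(BestOfError::TooLarge { max: 2, given: 3 }));
    assert_eq!(check_best_of(1, 2), Err(BestOfError::Disabled));
}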
@@ -150,6 +169,7 @@ fn validate(
     rng: &mut ThreadRng,
 ) -> Result<ValidGenerateRequest, ValidationError> {
     let GenerateParameters {
+        best_of,
         temperature,
         repetition_penalty,
         top_k,
@@ -217,7 +237,12 @@ fn validate(
     // If seed is None, assign a random one
     let seed = match seed {
         None => rng.gen(),
-        Some(seed) => seed,
+        Some(seed) => {
+            if best_of.unwrap_or(1) > 1 {
+                return Err(SeedBestOf);
+            }
+            seed
+        }
     };

     // Check if inputs is empty
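The seed check above can be read in isolation as: a caller-supplied `seed` is only legal when a single sequence is requested, presumably because reusing one seed across `best_of` samples would make the candidates identical. A minimal restatement, where `SeedBestOf` and the constant fallback seed are stand-ins for the router's types and random draw:

#[derive(Debug, PartialEq)]
struct SeedBestOf; // stand-in for ValidationError::SeedBestOf

// Mirrors the Some(seed) arm in the hunk above: accept an explicit seed
// only when at most one sequence is being generated.
fn resolve_seed(seed: Option<u64>, best_of: Option<usize>) -> Result<u64, SeedBestOf> {
    match seed {
        // the real code draws a random seed here; a constant keeps the sketch dependency-free
        None => Ok(42),
        Some(seed) => {
            if best_of.unwrap_or(1) > 1 {
                return Err(SeedBestOf);
            }
            Ok(seed)
        }
    }
}

fn main() {
    assert_eq!(resolve_seed(Some(7), None), Ok(7));
    assert_eq!(resolve_seed(Some(7), Some(1)), Ok(7));
    assert_eq!(resolve_seed(Some(7), Some(2)), Err(SeedBestOf));
    assert_eq!(resolve_seed(None, Some(2)), Ok(42));
}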
@@ -307,6 +332,14 @@ pub(crate) struct ValidGenerateRequest {

 #[derive(Error, Debug)]
 pub enum ValidationError {
+    #[error("`best_of` != 1 is not allowed for this endpoint")]
+    BestOfDisabled,
+    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
+    BestOf(usize, usize),
+    #[error("`best_of` != 1 is not supported when streaming tokens")]
+    StreamBestOf,
+    #[error("`seed` must not be set when `best_of` > 1")]
+    SeedBestOf,
     #[error("`temperature` must be strictly positive")]
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
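Because `ValidationError` derives `thiserror::Error`, each `#[error(...)]` string above becomes the variant's `Display` output. A hedged sketch reproducing just the variants introduced by this commit (assuming thiserror as a dependency):

use thiserror::Error;

// Only the new variants are reproduced here for illustration.
#[derive(Error, Debug)]
enum ValidationError {
    #[error("`best_of` != 1 is not allowed for this endpoint")]
    BestOfDisabled,
    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
    BestOf(usize, usize),
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    StreamBestOf,
    #[error("`seed` must not be set when `best_of` > 1")]
    SeedBestOf,
}

fn main() {
    // These are the messages that ultimately surface in error responses.
    assert_eq!(
        ValidationError::BestOf(2, 5).to_string(),
        "`best_of` must be > 0 and <= 2. Given: 5"
    );
    println!("{}", ValidationError::SeedBestOf);
}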