force sampling when using best_of

This commit is contained in:
OlivierDehaene 2023-03-09 14:50:42 +01:00
parent 9f4f2fc8e3
commit 8d7a0c1992
2 changed files with 23 additions and 9 deletions

View File

@ -422,7 +422,7 @@ async fn generate_stream(
yield Ok(Event::from(err)); yield Ok(Event::from(err));
} }
} else { } else {
let err = InferError::from(ValidationError::StreamBestOf); let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}"); tracing::error!("{err}");
yield Ok(Event::from(err)); yield Ok(Event::from(err));

View File

@ -1,4 +1,4 @@
use crate::validation::ValidationError::{EmptyInput, SeedBestOf}; use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
/// Payload validation logic /// Payload validation logic
use crate::{GenerateParameters, GenerateRequest}; use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng; use rand::rngs::ThreadRng;
@ -184,6 +184,18 @@ fn validate(
.. ..
} = request.parameters; } = request.parameters;
// sampling must be true when best_of > 1
let best_of = best_of.unwrap_or(1);
let sampling = do_sample
|| temperature.is_some()
|| top_k.is_some()
|| top_p.is_some()
|| typical_p.is_some();
if best_of > 1 && !sampling {
return Err(BestOfSampling);
}
let temperature = temperature.unwrap_or(1.0); let temperature = temperature.unwrap_or(1.0);
if temperature <= 0.0 { if temperature <= 0.0 {
return Err(ValidationError::Temperature); return Err(ValidationError::Temperature);
@ -238,8 +250,8 @@ fn validate(
let seed = match seed { let seed = match seed {
None => rng.gen(), None => rng.gen(),
Some(seed) => { Some(seed) => {
if best_of.unwrap_or(1) > 1 { if best_of > 1 {
return Err(SeedBestOf); return Err(BestOfSeed);
} }
seed seed
} }
@ -332,14 +344,16 @@ pub(crate) struct ValidGenerateRequest {
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ValidationError { pub enum ValidationError {
#[error("`best_of` != 1 is not allowed for this endpoint")]
BestOfDisabled,
#[error("`best_of` must be > 0 and <= {0}. Given: {1}")] #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
BestOf(usize, usize), BestOf(usize, usize),
#[error("`best_of` != 1 is not supported when streaming tokens")] #[error("`best_of` != 1 is not allowed for this endpoint")]
StreamBestOf, BestOfDisabled,
#[error("you must use sampling when `best_of` is > 1")]
BestOfSampling,
#[error("`seed` must not be set when `best_of` > 1")] #[error("`seed` must not be set when `best_of` > 1")]
SeedBestOf, BestOfSeed,
#[error("`best_of` != 1 is not supported when streaming tokens")]
BestOfStream,
#[error("`temperature` must be strictly positive")] #[error("`temperature` must be strictly positive")]
Temperature, Temperature,
#[error("`repetition_penalty` must be strictly positive")] #[error("`repetition_penalty` must be strictly positive")]