From 8d7a0c1992ffd565e20cea25be28a4c0639ba29b Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 9 Mar 2023 14:50:42 +0100 Subject: [PATCH] force sampling when using best_of --- router/src/server.rs | 2 +- router/src/validation.rs | 30 ++++++++++++++++++++++-------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index 3add4c7a..0983feea 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -422,7 +422,7 @@ async fn generate_stream( yield Ok(Event::from(err)); } } else { - let err = InferError::from(ValidationError::StreamBestOf); + let err = InferError::from(ValidationError::BestOfStream); metrics::increment_counter!("tgi_request_failure", "err" => "validation"); tracing::error!("{err}"); yield Ok(Event::from(err)); diff --git a/router/src/validation.rs b/router/src/validation.rs index 95a8eb9e..cb8dd0a2 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,4 +1,4 @@ -use crate::validation::ValidationError::{EmptyInput, SeedBestOf}; +use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; /// Payload validation logic use crate::{GenerateParameters, GenerateRequest}; use rand::rngs::ThreadRng; @@ -184,6 +184,18 @@ fn validate( .. } = request.parameters; + // sampling must be true when best_of > 1 + let best_of = best_of.unwrap_or(1); + let sampling = do_sample + || temperature.is_some() + || top_k.is_some() + || top_p.is_some() + || typical_p.is_some(); + + if best_of > 1 && !sampling { + return Err(BestOfSampling); + } + let temperature = temperature.unwrap_or(1.0); if temperature <= 0.0 { return Err(ValidationError::Temperature); @@ -238,8 +250,8 @@ fn validate( let seed = match seed { None => rng.gen(), Some(seed) => { - if best_of.unwrap_or(1) > 1 { - return Err(SeedBestOf); + if best_of > 1 { + return Err(BestOfSeed); } seed } @@ -332,14 +344,16 @@ pub(crate) struct ValidGenerateRequest { #[derive(Error, Debug)] pub enum ValidationError { - #[error("`best_of` != 1 is not allowed for this endpoint")] - BestOfDisabled, #[error("`best_of` must be > 0 and <= {0}. Given: {1}")] BestOf(usize, usize), - #[error("`best_of` != 1 is not supported when streaming tokens")] - StreamBestOf, + #[error("`best_of` != 1 is not allowed for this endpoint")] + BestOfDisabled, + #[error("you must use sampling when `best_of` is > 1")] + BestOfSampling, #[error("`seed` must not be set when `best_of` > 1")] - SeedBestOf, + BestOfSeed, + #[error("`best_of` != 1 is not supported when streaming tokens")] + BestOfStream, #[error("`temperature` must be strictly positive")] Temperature, #[error("`repetition_penalty` must be strictly positive")]