mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
force sampling when using best_of
This commit is contained in:
parent
9f4f2fc8e3
commit
8d7a0c1992
@ -422,7 +422,7 @@ async fn generate_stream(
|
||||
yield Ok(Event::from(err));
|
||||
}
|
||||
} else {
|
||||
let err = InferError::from(ValidationError::StreamBestOf);
|
||||
let err = InferError::from(ValidationError::BestOfStream);
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::validation::ValidationError::{EmptyInput, SeedBestOf};
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
/// Payload validation logic
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
use rand::rngs::ThreadRng;
|
||||
@ -184,6 +184,18 @@ fn validate(
|
||||
..
|
||||
} = request.parameters;
|
||||
|
||||
// sampling must be true when best_of > 1
|
||||
let best_of = best_of.unwrap_or(1);
|
||||
let sampling = do_sample
|
||||
|| temperature.is_some()
|
||||
|| top_k.is_some()
|
||||
|| top_p.is_some()
|
||||
|| typical_p.is_some();
|
||||
|
||||
if best_of > 1 && !sampling {
|
||||
return Err(BestOfSampling);
|
||||
}
|
||||
|
||||
let temperature = temperature.unwrap_or(1.0);
|
||||
if temperature <= 0.0 {
|
||||
return Err(ValidationError::Temperature);
|
||||
@ -238,8 +250,8 @@ fn validate(
|
||||
let seed = match seed {
|
||||
None => rng.gen(),
|
||||
Some(seed) => {
|
||||
if best_of.unwrap_or(1) > 1 {
|
||||
return Err(SeedBestOf);
|
||||
if best_of > 1 {
|
||||
return Err(BestOfSeed);
|
||||
}
|
||||
seed
|
||||
}
|
||||
@ -332,14 +344,16 @@ pub(crate) struct ValidGenerateRequest {
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ValidationError {
|
||||
#[error("`best_of` != 1 is not allowed for this endpoint")]
|
||||
BestOfDisabled,
|
||||
#[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
|
||||
BestOf(usize, usize),
|
||||
#[error("`best_of` != 1 is not supported when streaming tokens")]
|
||||
StreamBestOf,
|
||||
#[error("`best_of` != 1 is not allowed for this endpoint")]
|
||||
BestOfDisabled,
|
||||
#[error("you must use sampling when `best_of` is > 1")]
|
||||
BestOfSampling,
|
||||
#[error("`seed` must not be set when `best_of` > 1")]
|
||||
SeedBestOf,
|
||||
BestOfSeed,
|
||||
#[error("`best_of` != 1 is not supported when streaming tokens")]
|
||||
BestOfStream,
|
||||
#[error("`temperature` must be strictly positive")]
|
||||
Temperature,
|
||||
#[error("`repetition_penalty` must be strictly positive")]
|
||||
|
Loading…
Reference in New Issue
Block a user