mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00

force sampling when using best_of

commit 8d7a0c1992 (parent 9f4f2fc8e3)
@@ -422,7 +422,7 @@ async fn generate_stream(
                     yield Ok(Event::from(err));
                 }
             } else {
-                let err = InferError::from(ValidationError::StreamBestOf);
+                let err = InferError::from(ValidationError::BestOfStream);
                 metrics::increment_counter!("tgi_request_failure", "err" => "validation");
                 tracing::error!("{err}");
                 yield Ok(Event::from(err));
@@ -1,4 +1,4 @@
-use crate::validation::ValidationError::{EmptyInput, SeedBestOf};
+use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 /// Payload validation logic
 use crate::{GenerateParameters, GenerateRequest};
 use rand::rngs::ThreadRng;
@@ -184,6 +184,18 @@ fn validate(
         ..
     } = request.parameters;
 
+    // sampling must be true when best_of > 1
+    let best_of = best_of.unwrap_or(1);
+    let sampling = do_sample
+        || temperature.is_some()
+        || top_k.is_some()
+        || top_p.is_some()
+        || typical_p.is_some();
+
+    if best_of > 1 && !sampling {
+        return Err(BestOfSampling);
+    }
+
     let temperature = temperature.unwrap_or(1.0);
     if temperature <= 0.0 {
         return Err(ValidationError::Temperature);
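Read on its own, the block added above reduces to a single predicate over the request parameters: best_of > 1 is only accepted when sampling is in play, and setting any of temperature, top_k, top_p or typical_p counts as enabling sampling even if do_sample is false. A minimal, self-contained sketch of that rule follows; the parameter names mirror the diff, but the concrete types and the stand-in error enum are simplified assumptions, not the router's real definitions.

// Standalone sketch of the "best_of requires sampling" rule added above.
// Types and the error enum are simplified stand-ins for illustration only.
#[derive(Debug, PartialEq)]
enum ValidationError {
    BestOfSampling,
}

fn check_best_of_sampling(
    best_of: Option<usize>,
    do_sample: bool,
    temperature: Option<f32>,
    top_k: Option<i32>,
    top_p: Option<f32>,
    typical_p: Option<f32>,
) -> Result<(), ValidationError> {
    // sampling must be true when best_of > 1 (same rule as in the diff)
    let best_of = best_of.unwrap_or(1);
    let sampling = do_sample
        || temperature.is_some()
        || top_k.is_some()
        || top_p.is_some()
        || typical_p.is_some();
    if best_of > 1 && !sampling {
        return Err(ValidationError::BestOfSampling);
    }
    Ok(())
}

fn main() {
    // Greedy decoding with best_of > 1 is rejected...
    assert_eq!(
        check_best_of_sampling(Some(2), false, None, None, None, None),
        Err(ValidationError::BestOfSampling)
    );
    // ...while any sampling parameter (here, temperature) makes the request valid.
    assert!(check_best_of_sampling(Some(2), false, Some(0.8), None, None, None).is_ok());
}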
@@ -238,8 +250,8 @@ fn validate(
     let seed = match seed {
         None => rng.gen(),
         Some(seed) => {
-            if best_of.unwrap_or(1) > 1 {
-                return Err(SeedBestOf);
+            if best_of > 1 {
+                return Err(BestOfSeed);
             }
             seed
         }
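The hunk above reuses the best_of binding introduced earlier and keeps the existing seed rule: a missing seed is drawn from the thread-local RNG, while an explicit seed is rejected whenever best_of > 1, presumably so the sampled candidates can still differ. A minimal sketch of that logic, assuming the rand 0.8 crate and a simplified error type in place of the router's own:

// Sketch of the seed rule from the hunk above; stand-in types for illustration.
use rand::Rng;

#[derive(Debug)]
enum ValidationError {
    BestOfSeed,
}

fn resolve_seed(seed: Option<u64>, best_of: usize) -> Result<u64, ValidationError> {
    match seed {
        // No seed supplied: draw one from the thread-local RNG, as the router does with rng.gen()
        None => Ok(rand::thread_rng().gen()),
        Some(seed) => {
            // An explicit seed is only allowed when a single sequence is generated
            if best_of > 1 {
                return Err(ValidationError::BestOfSeed);
            }
            Ok(seed)
        }
    }
}

fn main() {
    assert!(resolve_seed(Some(42), 1).is_ok());
    assert!(resolve_seed(Some(42), 2).is_err());
    println!("random seed: {}", resolve_seed(None, 2).unwrap());
}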
@@ -332,14 +344,16 @@ pub(crate) struct ValidGenerateRequest {
 
 #[derive(Error, Debug)]
 pub enum ValidationError {
-    #[error("`best_of` != 1 is not allowed for this endpoint")]
-    BestOfDisabled,
     #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
     BestOf(usize, usize),
-    #[error("`best_of` != 1 is not supported when streaming tokens")]
-    StreamBestOf,
+    #[error("`best_of` != 1 is not allowed for this endpoint")]
+    BestOfDisabled,
+    #[error("you must use sampling when `best_of` is > 1")]
+    BestOfSampling,
     #[error("`seed` must not be set when `best_of` > 1")]
-    SeedBestOf,
+    BestOfSeed,
+    #[error("`best_of` != 1 is not supported when streaming tokens")]
+    BestOfStream,
     #[error("`temperature` must be strictly positive")]
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
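The enum derives Error, and the #[derive(Error, Debug)] / #[error(...)] attributes above are the thiserror pattern, so each attribute string becomes the variant's Display output; the rename also groups all best_of-related variants under a BestOf* prefix. A small sketch of how the renamed variants format, assuming the thiserror crate as a dependency and reproducing only the best_of-related variants from the diff:

// Sketch: thiserror's derive turns each #[error(...)] string into the Display impl.
use thiserror::Error;

#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("you must use sampling when `best_of` is > 1")]
    BestOfSampling,
    #[error("`seed` must not be set when `best_of` > 1")]
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
}

fn main() {
    // Prints: you must use sampling when `best_of` is > 1
    println!("{}", ValidationError::BestOfSampling);
    // Prints: `seed` must not be set when `best_of` > 1
    println!("{}", ValidationError::BestOfSeed);
}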