mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00

Add max_top_n_tokens CLI argument

commit 65c7b6207c (parent 8471e1862d)

The change threads a new `max_top_n_tokens` limit from the launcher CLI, through the router's arguments, into request validation, and adds matching error variants and tests.
@@ -130,6 +130,14 @@ struct Args {
     #[clap(default_value = "4", long, env)]
     max_stop_sequences: usize,
 
+    /// This is the maximum allowed value for clients to set `top_n_tokens`.
+    /// `top_n_tokens` is used to return information about the `n` most likely
+    /// tokens at each generation step, instead of just the sampled token. This
+    /// information can be used for downstream tasks such as classification or
+    /// ranking.
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+
     /// This is the maximum allowed input length (expressed in number of tokens)
     /// for users. The larger this value, the longer prompts users can send, which
     /// can impact the overall memory required to handle the load.
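The clap attribute above surfaces the new launcher limit as both a CLI flag and an environment variable. A minimal sketch of that behavior, assuming clap 3.x with the derive and env features enabled (illustrative, not the commit's code):

    use clap::Parser;

    #[derive(Parser, Debug)]
    struct Args {
        /// Maximum value clients may set for `top_n_tokens`.
        #[clap(default_value = "5", long, env)]
        max_top_n_tokens: u32,
    }

    fn main() {
        // clap derives the flag `--max-top-n-tokens` and the env var
        // `MAX_TOP_N_TOKENS` from the field name; the default is 5.
        let args = Args::parse();
        println!("max_top_n_tokens = {}", args.max_top_n_tokens);
    }

Running the binary with no arguments prints 5; passing `--max-top-n-tokens 8` or setting `MAX_TOP_N_TOKENS=8` overrides it.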
@@ -854,6 +862,8 @@ fn spawn_webserver(
         args.max_best_of.to_string(),
         "--max-stop-sequences".to_string(),
         args.max_stop_sequences.to_string(),
+        "--max-top-n-tokens".to_string(),
+        args.max_top_n_tokens.to_string(),
         "--max-input-length".to_string(),
         args.max_input_length.to_string(),
         "--max-total-tokens".to_string(),
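The hunk above sits in the launcher's spawn_webserver: the launcher does not share its Args struct with the router, so each limit is re-serialized as a flag string followed by its stringified value. A hedged sketch of that pattern (binary name and structure assumed for illustration):

    use std::process::Command;

    fn main() {
        let max_top_n_tokens: u32 = 5;
        // Each limit becomes a pair of argv entries: the flag, then the value.
        let router_args = vec![
            "--max-top-n-tokens".to_string(),
            max_top_n_tokens.to_string(),
        ];
        // The launcher then spawns the router process with these arguments.
        let mut cmd = Command::new("text-generation-router");
        cmd.args(&router_args);
        println!("{cmd:?}");
    }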
@@ -29,6 +29,8 @@ struct Args {
     max_best_of: usize,
     #[clap(default_value = "4", long, env)]
     max_stop_sequences: usize,
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
     #[clap(default_value = "1024", long, env)]
     max_input_length: usize,
     #[clap(default_value = "2048", long, env)]
@@ -75,6 +77,7 @@ fn main() -> Result<(), RouterError> {
         max_concurrent_requests,
         max_best_of,
         max_stop_sequences,
+        max_top_n_tokens,
         max_input_length,
         max_total_tokens,
         waiting_served_ratio,

@@ -255,6 +258,7 @@ fn main() -> Result<(), RouterError> {
         max_concurrent_requests,
         max_best_of,
         max_stop_sequences,
+        max_top_n_tokens,
         max_input_length,
         max_total_tokens,
         waiting_served_ratio,
@@ -235,7 +235,6 @@ impl State {
                 truncate: entry.request.truncate,
                 parameters: Some(entry.request.parameters.clone()),
                 stopping_parameters: Some(entry.request.stopping_parameters.clone()),
-                // TODO: Actually fill this from the request
                 top_n_tokens: entry.request.top_n_tokens,
 
             });
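The TODO comes out because, with this commit, validation resolves `top_n_tokens` to a concrete `u32` before the request is queued (0 when the client omits it), so the batching code in `impl State` can copy the value straight from the stored request. A small sketch with simplified stand-in types:

    struct ValidRequest {
        top_n_tokens: u32, // already resolved by validation; 0 disables the feature
    }

    struct Entry {
        request: ValidRequest,
    }

    fn main() {
        let entry = Entry { request: ValidRequest { top_n_tokens: 3 } };
        // Nothing left to fill in: the value comes from the validated request.
        assert_eq!(entry.request.top_n_tokens, 3);
    }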
@@ -520,6 +520,7 @@ pub async fn run(
     max_concurrent_requests: usize,
     max_best_of: usize,
     max_stop_sequences: usize,
+    max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
     waiting_served_ratio: f32,

@@ -582,6 +583,7 @@ pub async fn run(
         tokenizer,
         max_best_of,
         max_stop_sequences,
+        max_top_n_tokens,
         max_input_length,
         max_total_tokens,
     );
@@ -15,6 +15,7 @@ pub struct Validation {
     /// Validation parameters
     max_best_of: usize,
     max_stop_sequences: usize,
+    max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
     /// Channel to communicate with the background tokenization task

@@ -27,6 +28,7 @@ impl Validation {
         tokenizer: Option<Tokenizer>,
         max_best_of: usize,
         max_stop_sequences: usize,
+        max_top_n_tokens: u32,
         max_input_length: usize,
         max_total_tokens: usize,
     ) -> Self {

@@ -54,6 +56,7 @@ impl Validation {
             max_best_of,
             sender,
             max_stop_sequences,
+            max_top_n_tokens,
             max_input_length,
             max_total_tokens,
         }
@@ -142,7 +145,6 @@ impl Validation {
             seed,
             watermark,
             decoder_input_details,
-            // TODO: Validate top_n_tokens
             top_n_tokens,
             ..
         } = request.parameters;
@@ -220,6 +222,15 @@ impl Validation {
             }
         };
 
+        let top_n_tokens = top_n_tokens
+            .map(|value| {
+                if value > self.max_top_n_tokens {
+                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
+                }
+                Ok(value)
+            })
+            .unwrap_or(Ok(0))?;
+
         // Check if inputs is empty
         if request.inputs.is_empty() {
             return Err(EmptyInput);
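This clamp is the heart of the change: `None` resolves to 0 (the feature stays off), in-range values pass through, and anything above the configured maximum is rejected. A self-contained sketch of the same pattern with a stand-in error type:

    #[derive(Debug, PartialEq)]
    struct TopNTokensError(u32, u32); // (allowed maximum, value given)

    fn validate_top_n_tokens(
        top_n_tokens: Option<u32>,
        max_top_n_tokens: u32,
    ) -> Result<u32, TopNTokensError> {
        top_n_tokens
            .map(|value| {
                if value > max_top_n_tokens {
                    return Err(TopNTokensError(max_top_n_tokens, value));
                }
                Ok(value)
            })
            // `map` yields Option<Result<_, _>>; a missing value becomes Ok(0).
            .unwrap_or(Ok(0))
    }

    fn main() {
        assert_eq!(validate_top_n_tokens(None, 4), Ok(0));
        assert_eq!(validate_top_n_tokens(Some(4), 4), Ok(4));
        assert_eq!(validate_top_n_tokens(Some(5), 4), Err(TopNTokensError(4, 5)));
    }

In the real method the trailing `?` then propagates the error out of validate.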
@@ -265,7 +276,7 @@ impl Validation {
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
             parameters,
             stopping_parameters,
-            top_n_tokens: top_n_tokens.unwrap_or(0),
+            top_n_tokens: top_n_tokens,
         })
     }
 
@@ -354,6 +365,10 @@ pub enum ValidationError {
     BestOfSeed,
     #[error("`best_of` != 1 is not supported when streaming tokens")]
     BestOfStream,
+    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
+    TopNTokens(u32, u32),
+    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
+    TopNTokensDisabled,
     #[error("`decoder_input_details` == true is not supported when streaming tokens")]
     PrefillDetailsStream,
     #[error("`temperature` must be strictly positive")]
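The `#[error(...)]` strings are `thiserror` display formats, so the new variants interpolate the limit and the offending value directly into the message. A minimal sketch, assuming the thiserror crate (which the router already uses for ValidationError):

    use thiserror::Error;

    #[derive(Debug, Error)]
    enum ValidationError {
        #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
        TopNTokens(u32, u32),
        #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
        TopNTokensDisabled,
    }

    fn main() {
        // Prints: `top_n_tokens` must be >= 0 and <= 4. Given: 5
        println!("{}", ValidationError::TopNTokens(4, 5));
        println!("{}", ValidationError::TopNTokensDisabled);
    }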
@@ -395,14 +410,16 @@ mod tests {
         let tokenizer = None;
         let max_best_of = 2;
         let max_stop_sequence = 3;
-        let max_input_length = 4;
-        let max_total_tokens = 5;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
         let workers = 1;
         let validation = Validation::new(
             workers,
             tokenizer,
             max_best_of,
             max_stop_sequence,
+            max_top_n_tokens,
             max_input_length,
             max_total_tokens,
         );

@@ -422,14 +439,16 @@ mod tests {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
-        let max_input_length = 4;
-        let max_total_tokens = 5;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
         let workers = 1;
         let validation = Validation::new(
             workers,
             tokenizer,
             max_best_of,
             max_stop_sequence,
+            max_top_n_tokens,
             max_input_length,
             max_total_tokens,
         );

@@ -439,7 +458,7 @@ mod tests {
             .validate_input("Hello".to_string(), None, max_new_tokens)
             .await
         {
-            Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
+            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
             _ => panic!("Unexpected not max new tokens"),
         }
     }
@@ -449,14 +468,16 @@ mod tests {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
-        let max_input_length = 4;
-        let max_total_tokens = 5;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
         let workers = 1;
         let validation = Validation::new(
             workers,
             tokenizer,
             max_best_of,
             max_stop_sequence,
+            max_top_n_tokens,
             max_input_length,
             max_total_tokens,
         );

@@ -481,14 +502,16 @@ mod tests {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
-        let max_input_length = 4;
-        let max_total_tokens = 5;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
         let workers = 1;
         let validation = Validation::new(
             workers,
             tokenizer,
             max_best_of,
             max_stop_sequence,
+            max_top_n_tokens,
             max_input_length,
             max_total_tokens,
         );
@@ -535,4 +558,75 @@ mod tests {
         // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
         assert_eq!(valid_request.parameters.top_p, 1.0);
     }
+
+    #[tokio::test]
+    async fn test_validation_top_n_tokens() {
+        let tokenizer = Some(get_tokenizer().await);
+        let max_best_of = 2;
+        let max_stop_sequences = 3;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
+        let workers = 1;
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequences,
+            max_top_n_tokens,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_n_tokens: Some(5),
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
+            Err(ValidationError::TopNTokens(4, 5)) => (),
+            _ => panic!("Unexpected top_n_tokens"),
+        }
+
+        validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_n_tokens: Some(4),
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
+
+        validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_n_tokens: Some(0),
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
+
+        let valid_request = validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_n_tokens: None,
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
+
+        assert_eq!(valid_request.top_n_tokens, 0);
+    }
 }