diff --git a/router/src/main.rs b/router/src/main.rs index 33e9a8ff..5ababa4b 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -20,14 +20,12 @@ use tracing_subscriber::{EnvFilter, Layer}; struct Args { #[clap(default_value = "128", long, env)] max_concurrent_requests: usize, - #[clap(default_value = "512", long, env)] - max_max_new_tokens: u32, #[clap(default_value = "4", long, env)] max_stop_sequences: usize, - #[clap(default_value = "1512", long, env)] - max_total_tokens: usize, #[clap(default_value = "1000", long, env)] max_input_length: usize, + #[clap(default_value = "1512", long, env)] + max_total_tokens: usize, #[clap(default_value = "32", long, env)] max_batch_size: usize, #[clap(default_value = "20", long, env)] @@ -52,7 +50,6 @@ fn main() -> Result<(), std::io::Error> { // Pattern match configuration let Args { max_concurrent_requests, - max_max_new_tokens, max_stop_sequences, max_input_length, max_total_tokens, @@ -101,7 +98,6 @@ fn main() -> Result<(), std::io::Error> { // Run server server::run( max_concurrent_requests, - max_max_new_tokens, max_stop_sequences, max_input_length, max_total_tokens, diff --git a/router/src/server.rs b/router/src/server.rs index dd1ca9d8..19af1e78 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -291,7 +291,6 @@ async fn generate_stream( #[allow(clippy::too_many_arguments)] pub async fn run( max_concurrent_requests: usize, - max_max_new_tokens: u32, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -339,7 +338,6 @@ pub async fn run( let validation = Validation::new( validation_workers, tokenizer, - max_max_new_tokens, max_stop_sequences, max_input_length, max_total_tokens, diff --git a/router/src/validation.rs b/router/src/validation.rs index 664ed5f0..50d090cd 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -20,7 +20,6 @@ impl Validation { pub(crate) fn new( workers: usize, tokenizer: Tokenizer, - max_max_new_tokens: u32, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -32,7 +31,6 @@ impl Validation { tokio::spawn(validation_task( workers, tokenizer, - max_max_new_tokens, max_stop_sequences, max_input_length, max_total_tokens, @@ -69,7 +67,6 @@ impl Validation { async fn validation_task( workers: usize, tokenizer: Tokenizer, - max_max_new_tokens: u32, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -88,7 +85,6 @@ async fn validation_task( tokio::task::spawn_blocking(move || { validation_worker( tokenizer_clone, - max_max_new_tokens, max_stop_sequences, max_input_length, max_total_tokens, @@ -113,7 +109,6 @@ async fn validation_task( /// the tokenizer fn validation_worker( tokenizer: Tokenizer, - max_max_new_tokens: u32, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -130,7 +125,6 @@ fn validation_worker( validate( request, &tokenizer, - max_max_new_tokens, max_stop_sequences, max_input_length, max_total_tokens, @@ -149,7 +143,6 @@ fn validation_worker( fn validate( request: GenerateRequest, tokenizer: &Tokenizer, - max_max_new_tokens: u32, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -194,8 +187,8 @@ fn validate( } }?; - if max_new_tokens == 0 || max_new_tokens > max_max_new_tokens { - return Err(ValidationError::MaxNewTokens(max_max_new_tokens)); + if max_new_tokens == 0 { + return Err(ValidationError::MaxNewTokens); } if stop_sequences.len() > max_stop_sequences { @@ -226,7 +219,8 @@ fn validate( } else if total_tokens > max_total_tokens { Err(ValidationError::MaxTotalTokens( max_total_tokens, - total_tokens, + input_length, + max_new_tokens, )) } else { // Return ValidGenerateRequest @@ -279,10 +273,10 @@ pub enum ValidationError { TopP, #[error("top_k must be strictly positive")] TopK, - #[error("max_new_tokens must be strictly positive and <= {0}")] - MaxNewTokens(u32), - #[error("input tokens + max_new_tokens must be <= {0}. Given {1}")] - MaxTotalTokens(usize, usize), + #[error("max_new_tokens must be strictly positive")] + MaxNewTokens, + #[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")] + MaxTotalTokens(usize, usize, u32), #[error("inputs must have less than {0} tokens. Given: {1}")] InputLength(usize, usize), #[error("inputs cannot be empty")]