diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index d30d7f82..47f33430 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -77,8 +77,8 @@ struct Args {
     validation_workers: usize,
 
     /// Maximum amount of concurrent requests.
-    #[clap(default_value = "128", long, env)]
-    max_concurrent_requests: usize,
+    #[clap(long, env)]
+    max_concurrent_requests: Option<usize>,
 
     /// Maximum number of input tokens per request.
     #[clap(default_value = "1024", long, env)]
@@ -97,8 +97,8 @@ struct Args {
     max_physical_batch_total_tokens: Option<usize>,
 
     /// Maximum number of requests per batch.
-    #[clap(default_value = "1", long, env)]
-    max_batch_size: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
 
     /// IP address to listen on.
     #[clap(default_value = "0.0.0.0", long, env)]
@@ -175,14 +175,22 @@ async fn main() -> Result<(), RouterError> {
         Some(0) | None => n_threads,
         Some(threads) => threads,
     };
+    let max_batch_size = match args.max_batch_size {
+        Some(0) | None => n_threads_batch,
+        Some(threads) => threads,
+    };
     let max_batch_total_tokens = match args.max_batch_total_tokens {
-        None => args.max_batch_size * args.max_total_tokens,
+        None => max_batch_size * args.max_total_tokens,
         Some(size) => size,
     };
     let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
         None => max_batch_total_tokens,
         Some(size) => size,
     };
+    let max_concurrent_requests = match args.max_concurrent_requests {
+        None => max_batch_size * 2,
+        Some(size) => size,
+    };
     if args.max_input_tokens >= args.max_total_tokens {
         return Err(RouterError::ArgumentValidation(
             "`max_input_tokens` must be < `max_total_tokens`".to_string(),
@@ -193,7 +201,7 @@ async fn main() -> Result<(), RouterError> {
             "`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
         ));
     }
-    if args.max_batch_size * args.max_total_tokens > max_batch_total_tokens {
+    if max_batch_size * args.max_total_tokens > max_batch_total_tokens {
         return Err(RouterError::ArgumentValidation(
             "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
         ));
@@ -233,7 +241,7 @@ async fn main() -> Result<(), RouterError> {
             offload_kqv: args.offload_kqv,
             max_batch_total_tokens: max_batch_total_tokens,
             max_physical_batch_total_tokens: max_physical_batch_total_tokens,
-            max_batch_size: args.max_batch_size,
+            max_batch_size: max_batch_size,
             batch_timeout: tokio::time::Duration::from_millis(5),
         },
         tokenizer,
@@ -250,7 +258,7 @@ async fn main() -> Result<(), RouterError> {
 
     server::run(
         backend,
-        args.max_concurrent_requests,
+        max_concurrent_requests,
         0, // max_best_of
         0, // max_stop_sequences
         0, // max_top_n_tokens