From b3e40c4b66da359451599f00d51a0eb55f181609 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Wed, 5 Feb 2025 16:38:52 +0000 Subject: [PATCH] Improve default settings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- backends/llamacpp/src/main.rs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs index d30d7f82..47f33430 100644 --- a/backends/llamacpp/src/main.rs +++ b/backends/llamacpp/src/main.rs @@ -77,8 +77,8 @@ struct Args { validation_workers: usize, /// Maximum amount of concurrent requests. - #[clap(default_value = "128", long, env)] - max_concurrent_requests: usize, + #[clap(long, env)] + max_concurrent_requests: Option<usize>, /// Maximum number of input tokens per request. #[clap(default_value = "1024", long, env)] @@ -97,8 +97,8 @@ struct Args { max_physical_batch_total_tokens: Option<usize>, /// Maximum number of requests per batch. - #[clap(default_value = "1", long, env)] - max_batch_size: usize, + #[clap(long, env)] + max_batch_size: Option<usize>, /// IP address to listen on. 
#[clap(default_value = "0.0.0.0", long, env)] @@ -175,14 +175,22 @@ async fn main() -> Result<(), RouterError> { Some(0) | None => n_threads, Some(threads) => threads, }; + let max_batch_size = match args.max_batch_size { + Some(0) | None => n_threads_batch, + Some(threads) => threads, + }; let max_batch_total_tokens = match args.max_batch_total_tokens { - None => args.max_batch_size * args.max_total_tokens, + None => max_batch_size * args.max_total_tokens, Some(size) => size, }; let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens { None => max_batch_total_tokens, Some(size) => size, }; + let max_concurrent_requests = match args.max_concurrent_requests { + None => max_batch_size * 2, + Some(size) => size, + }; if args.max_input_tokens >= args.max_total_tokens { return Err(RouterError::ArgumentValidation( "`max_input_tokens` must be < `max_total_tokens`".to_string(), @@ -193,7 +201,7 @@ async fn main() -> Result<(), RouterError> { "`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(), )); } - if args.max_batch_size * args.max_total_tokens > max_batch_total_tokens { + if max_batch_size * args.max_total_tokens > max_batch_total_tokens { return Err(RouterError::ArgumentValidation( "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(), )); @@ -233,7 +241,7 @@ async fn main() -> Result<(), RouterError> { offload_kqv: args.offload_kqv, max_batch_total_tokens: max_batch_total_tokens, max_physical_batch_total_tokens: max_physical_batch_total_tokens, - max_batch_size: args.max_batch_size, + max_batch_size: max_batch_size, batch_timeout: tokio::time::Duration::from_millis(5), }, tokenizer, @@ -250,7 +258,7 @@ async fn main() -> Result<(), RouterError> { server::run( backend, - args.max_concurrent_requests, + max_concurrent_requests, 0, // max_best_of 0, // max_stop_sequences 0, // max_top_n_tokens