diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs index 310ca8f1..d30d7f82 100644 --- a/backends/llamacpp/src/main.rs +++ b/backends/llamacpp/src/main.rs @@ -89,8 +89,8 @@ struct Args { max_total_tokens: usize, /// Maximum number of tokens in a batch. - #[clap(default_value = "4096", long, env)] - max_batch_total_tokens: usize, + #[clap(long, env)] + max_batch_total_tokens: Option<usize>, /// Maximum number of tokens in a physical batch. #[clap(long, env)] @@ -175,8 +175,12 @@ async fn main() -> Result<(), RouterError> { Some(0) | None => n_threads, Some(threads) => threads, }; + let max_batch_total_tokens = match args.max_batch_total_tokens { + None => args.max_batch_size * args.max_total_tokens, + Some(size) => size, + }; let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens { - None => args.max_batch_total_tokens, + None => max_batch_total_tokens, Some(size) => size, }; if args.max_input_tokens >= args.max_total_tokens { @@ -184,12 +188,12 @@ async fn main() -> Result<(), RouterError> { "`max_input_tokens` must be < `max_total_tokens`".to_string(), )); } - if args.max_total_tokens > args.max_batch_total_tokens { + if args.max_total_tokens > max_batch_total_tokens { return Err(RouterError::ArgumentValidation( "`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(), )); } - if args.max_batch_size * args.max_total_tokens > args.max_batch_total_tokens { + if args.max_batch_size * args.max_total_tokens > max_batch_total_tokens { return Err(RouterError::ArgumentValidation( "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(), )); @@ -227,7 +231,7 @@ async fn main() -> Result<(), RouterError> { type_k: args.type_k, type_v: args.type_v, offload_kqv: args.offload_kqv, - max_batch_total_tokens: args.max_batch_total_tokens, + max_batch_total_tokens: max_batch_total_tokens, max_physical_batch_total_tokens: max_physical_batch_total_tokens, max_batch_size: args.max_batch_size, 
batch_timeout: tokio::time::Duration::from_millis(5),