From c8505fb300735da0f09eff06f4dcd9ecb62ad78c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Sat, 1 Feb 2025 18:33:26 +0000 Subject: [PATCH] Auto-detect n_threads when not provided MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- Cargo.lock | 1 + backends/llamacpp/Cargo.toml | 1 + backends/llamacpp/src/main.rs | 10 +++++++--- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 73ed43c6..902fe7e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4643,6 +4643,7 @@ dependencies = [ "async-trait", "bindgen 0.71.1", "clap 4.5.27", + "num_cpus", "pkg-config", "text-generation-router", "thiserror 2.0.11", diff --git a/backends/llamacpp/Cargo.toml b/backends/llamacpp/Cargo.toml index b1ff3c3f..18c2ed0a 100644 --- a/backends/llamacpp/Cargo.toml +++ b/backends/llamacpp/Cargo.toml @@ -12,6 +12,7 @@ pkg-config = "0.3.31" [dependencies] async-trait = "0.1.85" clap = "4.5.27" +num_cpus = "1.16.0" text-generation-router = { path = "../../router" } thiserror = "2.0.11" tokenizers.workspace = true diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs index b5eec467..f3e81782 100644 --- a/backends/llamacpp/src/main.rs +++ b/backends/llamacpp/src/main.rs @@ -29,8 +29,8 @@ struct Args { n_ctx: usize, /// Number of threads to use for inference. - #[clap(default_value = "1", long, env)] - n_threads: usize, + #[clap(long, env)] + n_threads: Option<usize>, /// Number of layers to store in VRAM. 
#[clap(default_value = "0", long, env)] @@ -155,6 +155,10 @@ async fn main() -> Result<(), RouterError> { args.json_output ); + let n_threads = match args.n_threads { + Some(0) | None => num_cpus::get(), + Some(threads) => threads, + }; if args.max_input_tokens >= args.max_total_tokens { return Err(RouterError::ArgumentValidation( "`max_input_tokens` must be < `max_total_tokens`".to_string(), @@ -197,7 +201,7 @@ async fn main() -> Result<(), RouterError> { LlamacppConfig { model_gguf: args.model_gguf, n_ctx: args.n_ctx, - n_threads: args.n_threads, + n_threads: n_threads, n_gpu_layers: args.n_gpu_layers, split_mode: args.split_mode, defrag_threshold: args.defrag_threshold,