Auto-detect n_threads when not provided or set to 0

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2025-02-01 18:33:26 +00:00
parent 27534d8ee4
commit c8505fb300
No known key found for this signature in database
3 changed files with 9 additions and 3 deletions

1
Cargo.lock generated
View File

@@ -4643,6 +4643,7 @@ dependencies = [
"async-trait", "async-trait",
"bindgen 0.71.1", "bindgen 0.71.1",
"clap 4.5.27", "clap 4.5.27",
"num_cpus",
"pkg-config", "pkg-config",
"text-generation-router", "text-generation-router",
"thiserror 2.0.11", "thiserror 2.0.11",

View File

@@ -12,6 +12,7 @@ pkg-config = "0.3.31"
[dependencies] [dependencies]
async-trait = "0.1.85" async-trait = "0.1.85"
clap = "4.5.27" clap = "4.5.27"
num_cpus = "1.16.0"
text-generation-router = { path = "../../router" } text-generation-router = { path = "../../router" }
thiserror = "2.0.11" thiserror = "2.0.11"
tokenizers.workspace = true tokenizers.workspace = true

View File

@@ -29,8 +29,8 @@ struct Args {
n_ctx: usize, n_ctx: usize,
/// Number of threads to use for inference. /// Number of threads to use for inference.
#[clap(default_value = "1", long, env)] #[clap(long, env)]
n_threads: usize, n_threads: Option<usize>,
/// Number of layers to store in VRAM. /// Number of layers to store in VRAM.
#[clap(default_value = "0", long, env)] #[clap(default_value = "0", long, env)]
@@ -155,6 +155,10 @@ async fn main() -> Result<(), RouterError> {
args.json_output args.json_output
); );
let n_threads = match args.n_threads {
Some(0) | None => num_cpus::get(),
Some(threads) => threads,
};
if args.max_input_tokens >= args.max_total_tokens { if args.max_input_tokens >= args.max_total_tokens {
return Err(RouterError::ArgumentValidation( return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(), "`max_input_tokens` must be < `max_total_tokens`".to_string(),
@@ -197,7 +201,7 @@ async fn main() -> Result<(), RouterError> {
LlamacppConfig { LlamacppConfig {
model_gguf: args.model_gguf, model_gguf: args.model_gguf,
n_ctx: args.n_ctx, n_ctx: args.n_ctx,
n_threads: args.n_threads, n_threads: n_threads,
n_gpu_layers: args.n_gpu_layers, n_gpu_layers: args.n_gpu_layers,
split_mode: args.split_mode, split_mode: args.split_mode,
defrag_threshold: args.defrag_threshold, defrag_threshold: args.defrag_threshold,