mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Auto-detect n_threads when not provided
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
27534d8ee4
commit
c8505fb300
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -4643,6 +4643,7 @@ dependencies = [
|
|||||||
"async-trait",
|
"async-trait",
|
||||||
"bindgen 0.71.1",
|
"bindgen 0.71.1",
|
||||||
"clap 4.5.27",
|
"clap 4.5.27",
|
||||||
|
"num_cpus",
|
||||||
"pkg-config",
|
"pkg-config",
|
||||||
"text-generation-router",
|
"text-generation-router",
|
||||||
"thiserror 2.0.11",
|
"thiserror 2.0.11",
|
||||||
|
@ -12,6 +12,7 @@ pkg-config = "0.3.31"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
async-trait = "0.1.85"
|
async-trait = "0.1.85"
|
||||||
clap = "4.5.27"
|
clap = "4.5.27"
|
||||||
|
num_cpus = "1.16.0"
|
||||||
text-generation-router = { path = "../../router" }
|
text-generation-router = { path = "../../router" }
|
||||||
thiserror = "2.0.11"
|
thiserror = "2.0.11"
|
||||||
tokenizers.workspace = true
|
tokenizers.workspace = true
|
||||||
|
@ -29,8 +29,8 @@ struct Args {
|
|||||||
n_ctx: usize,
|
n_ctx: usize,
|
||||||
|
|
||||||
/// Number of threads to use for inference.
|
/// Number of threads to use for inference.
|
||||||
#[clap(default_value = "1", long, env)]
|
#[clap(long, env)]
|
||||||
n_threads: usize,
|
n_threads: Option<usize>,
|
||||||
|
|
||||||
/// Number of layers to store in VRAM.
|
/// Number of layers to store in VRAM.
|
||||||
#[clap(default_value = "0", long, env)]
|
#[clap(default_value = "0", long, env)]
|
||||||
@ -155,6 +155,10 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
args.json_output
|
args.json_output
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let n_threads = match args.n_threads {
|
||||||
|
Some(0) | None => num_cpus::get(),
|
||||||
|
Some(threads) => threads,
|
||||||
|
};
|
||||||
if args.max_input_tokens >= args.max_total_tokens {
|
if args.max_input_tokens >= args.max_total_tokens {
|
||||||
return Err(RouterError::ArgumentValidation(
|
return Err(RouterError::ArgumentValidation(
|
||||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||||
@ -197,7 +201,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
LlamacppConfig {
|
LlamacppConfig {
|
||||||
model_gguf: args.model_gguf,
|
model_gguf: args.model_gguf,
|
||||||
n_ctx: args.n_ctx,
|
n_ctx: args.n_ctx,
|
||||||
n_threads: args.n_threads,
|
n_threads: n_threads,
|
||||||
n_gpu_layers: args.n_gpu_layers,
|
n_gpu_layers: args.n_gpu_layers,
|
||||||
split_mode: args.split_mode,
|
split_mode: args.split_mode,
|
||||||
defrag_threshold: args.defrag_threshold,
|
defrag_threshold: args.defrag_threshold,
|
||||||
|
Loading…
Reference in New Issue
Block a user