Improve default settings

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2025-02-05 16:38:52 +00:00
parent f22e2fb550
commit b3e40c4b66
No known key found for this signature in database

View File

@ -77,8 +77,8 @@ struct Args {
validation_workers: usize,
/// Maximum amount of concurrent requests.
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(long, env)]
max_concurrent_requests: Option<usize>,
/// Maximum number of input tokens per request.
#[clap(default_value = "1024", long, env)]
@ -97,8 +97,8 @@ struct Args {
max_physical_batch_total_tokens: Option<usize>,
/// Maximum number of requests per batch.
#[clap(default_value = "1", long, env)]
max_batch_size: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
/// IP address to listen on.
#[clap(default_value = "0.0.0.0", long, env)]
@ -175,14 +175,22 @@ async fn main() -> Result<(), RouterError> {
Some(0) | None => n_threads,
Some(threads) => threads,
};
let max_batch_size = match args.max_batch_size {
Some(0) | None => n_threads_batch,
Some(threads) => threads,
};
let max_batch_total_tokens = match args.max_batch_total_tokens {
None => args.max_batch_size * args.max_total_tokens,
None => max_batch_size * args.max_total_tokens,
Some(size) => size,
};
let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
None => max_batch_total_tokens,
Some(size) => size,
};
let max_concurrent_requests = match args.max_concurrent_requests {
None => max_batch_size * 2,
Some(size) => size,
};
if args.max_input_tokens >= args.max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
@ -193,7 +201,7 @@ async fn main() -> Result<(), RouterError> {
"`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
));
}
if args.max_batch_size * args.max_total_tokens > max_batch_total_tokens {
if max_batch_size * args.max_total_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
));
@ -233,7 +241,7 @@ async fn main() -> Result<(), RouterError> {
offload_kqv: args.offload_kqv,
max_batch_total_tokens: max_batch_total_tokens,
max_physical_batch_total_tokens: max_physical_batch_total_tokens,
max_batch_size: args.max_batch_size,
max_batch_size: max_batch_size,
batch_timeout: tokio::time::Duration::from_millis(5),
},
tokenizer,
@ -250,7 +258,7 @@ async fn main() -> Result<(), RouterError> {
server::run(
backend,
args.max_concurrent_requests,
max_concurrent_requests,
0, // max_best_of
0, // max_stop_sequences
0, // max_top_n_tokens