Make max_batch_total_tokens optional

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2025-02-05 11:40:20 +00:00
parent 09a745f1b8
commit 5b777877b1
No known key found for this signature in database

View File

@ -89,8 +89,8 @@ struct Args {
max_total_tokens: usize,
/// Maximum number of tokens in a batch.
#[clap(default_value = "4096", long, env)]
max_batch_total_tokens: usize,
#[clap(long, env)]
max_batch_total_tokens: Option<usize>,
/// Maximum number of tokens in a physical batch.
#[clap(long, env)]
@ -175,8 +175,12 @@ async fn main() -> Result<(), RouterError> {
Some(0) | None => n_threads,
Some(threads) => threads,
};
let max_batch_total_tokens = match args.max_batch_total_tokens {
None => args.max_batch_size * args.max_total_tokens,
Some(size) => size,
};
let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
None => args.max_batch_total_tokens,
None => max_batch_total_tokens,
Some(size) => size,
};
if args.max_input_tokens >= args.max_total_tokens {
@ -184,12 +188,12 @@ async fn main() -> Result<(), RouterError> {
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if args.max_total_tokens > args.max_batch_total_tokens {
if args.max_total_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
));
}
if args.max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
if args.max_batch_size * args.max_total_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
));
@ -227,7 +231,7 @@ async fn main() -> Result<(), RouterError> {
type_k: args.type_k,
type_v: args.type_v,
offload_kqv: args.offload_kqv,
max_batch_total_tokens: args.max_batch_total_tokens,
max_batch_total_tokens: max_batch_total_tokens,
max_physical_batch_total_tokens: max_physical_batch_total_tokens,
max_batch_size: args.max_batch_size,
batch_timeout: tokio::time::Duration::from_millis(5),