Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Make max_batch_total_tokens optional
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
parent 09a745f1b8
commit 5b777877b1
@@ -89,8 +89,8 @@ struct Args {
     max_total_tokens: usize,

     /// Maximum number of tokens in a batch.
-    #[clap(default_value = "4096", long, env)]
-    max_batch_total_tokens: usize,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<usize>,

     /// Maximum number of tokens in a physical batch.
     #[clap(long, env)]
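Note (not part of the commit): with clap's derive API, dropping the default_value and switching the field to Option<usize> is what makes the flag truly optional, since an omitted flag now parses to None instead of a baked-in 4096. A minimal, hypothetical stand-alone sketch of that behavior (assuming clap with the derive and env features; names are illustrative):

use clap::Parser;

/// Illustrative struct, not the router's real Args.
#[derive(Parser, Debug)]
struct Demo {
    /// Maximum number of tokens in a batch; None means the flag was omitted.
    #[clap(long, env)]
    max_batch_total_tokens: Option<usize>,
}

fn main() {
    let demo = Demo::parse();
    match demo.max_batch_total_tokens {
        Some(n) => println!("--max-batch-total-tokens set to {n}"),
        None => println!("flag not set; a default will be derived later"),
    }
}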
@@ -175,8 +175,12 @@ async fn main() -> Result<(), RouterError> {
         Some(0) | None => n_threads,
         Some(threads) => threads,
     };
+    let max_batch_total_tokens = match args.max_batch_total_tokens {
+        None => args.max_batch_size * args.max_total_tokens,
+        Some(size) => size,
+    };
     let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
-        None => args.max_batch_total_tokens,
+        None => max_batch_total_tokens,
         Some(size) => size,
     };
     if args.max_input_tokens >= args.max_total_tokens {
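Note (not part of the commit): when the flag is unset, the new fallback budgets for max_batch_size full sequences of max_total_tokens tokens each. A rough self-contained sketch of the same resolution logic, with made-up numbers:

fn resolve_max_batch_total_tokens(
    flag: Option<usize>,
    max_batch_size: usize,
    max_total_tokens: usize,
) -> usize {
    match flag {
        // Unset: room for max_batch_size sequences of max_total_tokens each.
        None => max_batch_size * max_total_tokens,
        // Explicitly set: the given value wins.
        Some(size) => size,
    }
}

fn main() {
    // e.g. 8 sequences of up to 4096 tokens -> a 32768-token batch budget.
    assert_eq!(resolve_max_batch_total_tokens(None, 8, 4096), 32_768);
    assert_eq!(resolve_max_batch_total_tokens(Some(16_384), 8, 4096), 16_384);
    println!("ok");
}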
@@ -184,12 +188,12 @@ async fn main() -> Result<(), RouterError> {
             "`max_input_tokens` must be < `max_total_tokens`".to_string(),
         ));
     }
-    if args.max_total_tokens > args.max_batch_total_tokens {
+    if args.max_total_tokens > max_batch_total_tokens {
         return Err(RouterError::ArgumentValidation(
             "`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
         ));
     }
-    if args.max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
+    if args.max_batch_size * args.max_total_tokens > max_batch_total_tokens {
         return Err(RouterError::ArgumentValidation(
             "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
         ));
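Note (not part of the commit): these checks now compare against the resolved value, and a derived default of max_batch_size * max_total_tokens satisfies the last inequality by construction (it holds with equality). A hedged stand-alone sketch of the same invariants, using a plain String error instead of RouterError:

fn validate(
    max_input_tokens: usize,
    max_total_tokens: usize,
    max_batch_size: usize,
    max_batch_total_tokens: usize,
) -> Result<(), String> {
    if max_input_tokens >= max_total_tokens {
        return Err("`max_input_tokens` must be < `max_total_tokens`".to_string());
    }
    if max_total_tokens > max_batch_total_tokens {
        return Err("`max_total_tokens` must be <= `max_batch_total_tokens`".to_string());
    }
    if max_batch_size * max_total_tokens > max_batch_total_tokens {
        return Err("`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string());
    }
    Ok(())
}

fn main() {
    // The derived default (8 * 4096 = 32768) passes every check.
    assert!(validate(1024, 4096, 8, 8 * 4096).is_ok());
    // An explicit budget smaller than one full batch is rejected.
    assert!(validate(1024, 4096, 8, 16_384).is_err());
    println!("ok");
}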
@@ -227,7 +231,7 @@ async fn main() -> Result<(), RouterError> {
         type_k: args.type_k,
         type_v: args.type_v,
         offload_kqv: args.offload_kqv,
-        max_batch_total_tokens: args.max_batch_total_tokens,
+        max_batch_total_tokens: max_batch_total_tokens,
         max_physical_batch_total_tokens: max_physical_batch_total_tokens,
         max_batch_size: args.max_batch_size,
         batch_timeout: tokio::time::Duration::from_millis(5),