mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Make max_batch_total_tokens optional
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
09a745f1b8
commit
5b777877b1
@ -89,8 +89,8 @@ struct Args {
|
||||
max_total_tokens: usize,
|
||||
|
||||
/// Maximum number of tokens in a batch.
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_total_tokens: usize,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<usize>,
|
||||
|
||||
/// Maximum number of tokens in a physical batch.
|
||||
#[clap(long, env)]
|
||||
@ -175,8 +175,12 @@ async fn main() -> Result<(), RouterError> {
|
||||
Some(0) | None => n_threads,
|
||||
Some(threads) => threads,
|
||||
};
|
||||
let max_batch_total_tokens = match args.max_batch_total_tokens {
|
||||
None => args.max_batch_size * args.max_total_tokens,
|
||||
Some(size) => size,
|
||||
};
|
||||
let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
|
||||
None => args.max_batch_total_tokens,
|
||||
None => max_batch_total_tokens,
|
||||
Some(size) => size,
|
||||
};
|
||||
if args.max_input_tokens >= args.max_total_tokens {
|
||||
@ -184,12 +188,12 @@ async fn main() -> Result<(), RouterError> {
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if args.max_total_tokens > args.max_batch_total_tokens {
|
||||
if args.max_total_tokens > max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if args.max_batch_size * args.max_total_tokens > args.max_batch_total_tokens {
|
||||
if args.max_batch_size * args.max_total_tokens > max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
|
||||
));
|
||||
@ -227,7 +231,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
type_k: args.type_k,
|
||||
type_v: args.type_v,
|
||||
offload_kqv: args.offload_kqv,
|
||||
max_batch_total_tokens: args.max_batch_total_tokens,
|
||||
max_batch_total_tokens: max_batch_total_tokens,
|
||||
max_physical_batch_total_tokens: max_physical_batch_total_tokens,
|
||||
max_batch_size: args.max_batch_size,
|
||||
batch_timeout: tokio::time::Duration::from_millis(5),
|
||||
|
Loading…
Reference in New Issue
Block a user