mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Improve default settings
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
f22e2fb550
commit
b3e40c4b66
@ -77,8 +77,8 @@ struct Args {
|
||||
validation_workers: usize,
|
||||
|
||||
/// Maximum amount of concurrent requests.
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(long, env)]
|
||||
max_concurrent_requests: Option<usize>,
|
||||
|
||||
/// Maximum number of input tokens per request.
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
@ -97,8 +97,8 @@ struct Args {
|
||||
max_physical_batch_total_tokens: Option<usize>,
|
||||
|
||||
/// Maximum number of requests per batch.
|
||||
#[clap(default_value = "1", long, env)]
|
||||
max_batch_size: usize,
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
|
||||
/// IP address to listen on.
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
@ -175,14 +175,22 @@ async fn main() -> Result<(), RouterError> {
|
||||
Some(0) | None => n_threads,
|
||||
Some(threads) => threads,
|
||||
};
|
||||
let max_batch_size = match args.max_batch_size {
|
||||
Some(0) | None => n_threads_batch,
|
||||
Some(threads) => threads,
|
||||
};
|
||||
let max_batch_total_tokens = match args.max_batch_total_tokens {
|
||||
None => args.max_batch_size * args.max_total_tokens,
|
||||
None => max_batch_size * args.max_total_tokens,
|
||||
Some(size) => size,
|
||||
};
|
||||
let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
|
||||
None => max_batch_total_tokens,
|
||||
Some(size) => size,
|
||||
};
|
||||
let max_concurrent_requests = match args.max_concurrent_requests {
|
||||
None => max_batch_size * 2,
|
||||
Some(size) => size,
|
||||
};
|
||||
if args.max_input_tokens >= args.max_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
@ -193,7 +201,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
"`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if args.max_batch_size * args.max_total_tokens > max_batch_total_tokens {
|
||||
if max_batch_size * args.max_total_tokens > max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
|
||||
));
|
||||
@ -233,7 +241,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
offload_kqv: args.offload_kqv,
|
||||
max_batch_total_tokens: max_batch_total_tokens,
|
||||
max_physical_batch_total_tokens: max_physical_batch_total_tokens,
|
||||
max_batch_size: args.max_batch_size,
|
||||
max_batch_size: max_batch_size,
|
||||
batch_timeout: tokio::time::Duration::from_millis(5),
|
||||
},
|
||||
tokenizer,
|
||||
@ -250,7 +258,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
|
||||
server::run(
|
||||
backend,
|
||||
args.max_concurrent_requests,
|
||||
max_concurrent_requests,
|
||||
0, // max_best_of
|
||||
0, // max_stop_sequences
|
||||
0, // max_top_n_tokens
|
||||
|
Loading…
Reference in New Issue
Block a user