Handle max_batch_size

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2025-01-30 19:50:09 +00:00
parent 3eb4823f3e
commit e7facf692f
No known key found for this signature in database
2 changed files with 17 additions and 4 deletions

View File

@ -25,6 +25,7 @@ pub struct LlamacppConfig {
pub model_gguf: String,
pub n_ctx: u32,
pub max_batch_total_tokens: u32,
pub max_batch_size: Option<usize>,
pub batch_timeout: Duration,
pub n_threads: i32,
pub use_mmap: bool,
@ -320,13 +321,22 @@ impl LlamacppBackend {
match timeout(conf.batch_timeout, rx.recv()).await {
Ok(None) => break, // closed
Ok(Some(request)) => {
if let Some(max_batch_size) = conf.max_batch_size {
if requests.len() + 1 == max_batch_size {
requests.push(request);
let _ = sync_tx.send(requests);
n_tokens = 0;
requests = Vec::new();
continue;
}
}
if n_tokens + request.input_ids.len() > conf.max_batch_total_tokens as usize {
let _ = sync_tx.send(requests);
n_tokens = request.input_ids.len();
requests = vec![request];
} else {
requests.push(request);
continue;
}
requests.push(request);
},
Err(_) => {
if !requests.is_empty() {

View File

@ -75,8 +75,10 @@ struct Args {
// #[clap(default_value = "20", long, env)]
// max_waiting_tokens: usize,
// #[clap(long, env)]
// max_batch_size: Option<usize>,
/// Maximum number of requests per batch
#[clap(long, env)]
max_batch_size: Option<usize>,
/// The IP address to listen on
#[clap(default_value = "0.0.0.0", long, env)]
@ -165,6 +167,7 @@ async fn main() -> Result<(), RouterError> {
use_mlock: args.use_mlock,
flash_attention: args.flash_attention,
max_batch_total_tokens: args.max_batch_total_tokens,
max_batch_size: args.max_batch_size,
batch_timeout: tokio::time::Duration::from_millis(100),
},
tokenizer,