From e7facf692f9a73e00743f9e92f10e2b3c58a47e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Thu, 30 Jan 2025 19:50:09 +0000 Subject: [PATCH] Handle max_batch_size MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- backends/llamacpp/src/backend.rs | 14 ++++++++++++-- backends/llamacpp/src/main.rs | 7 +++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs index 88726756..38b21ce2 100644 --- a/backends/llamacpp/src/backend.rs +++ b/backends/llamacpp/src/backend.rs @@ -25,6 +25,7 @@ pub struct LlamacppConfig { pub model_gguf: String, pub n_ctx: u32, pub max_batch_total_tokens: u32, + pub max_batch_size: Option<usize>, pub batch_timeout: Duration, pub n_threads: i32, pub use_mmap: bool, @@ -320,13 +321,22 @@ impl LlamacppBackend { match timeout(conf.batch_timeout, rx.recv()).await { Ok(None) => break, // closed Ok(Some(request)) => { + if let Some(max_batch_size) = conf.max_batch_size { + if requests.len() + 1 == max_batch_size { + requests.push(request); + let _ = sync_tx.send(requests); + n_tokens = 0; + requests = Vec::new(); + continue; + } + } if n_tokens + request.input_ids.len() > conf.max_batch_total_tokens as usize { let _ = sync_tx.send(requests); n_tokens = request.input_ids.len(); requests = vec![request]; - } else { - requests.push(request); + continue; } + requests.push(request); }, Err(_) => { if !requests.is_empty() { diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs index ea6743a7..a9bebd88 100644 --- a/backends/llamacpp/src/main.rs +++ b/backends/llamacpp/src/main.rs @@ -75,8 +75,10 @@ struct Args { // #[clap(default_value = "20", long, env)] // max_waiting_tokens: usize, -// #[clap(long, env)] -// max_batch_size: Option<usize>, + + /// Maximum number of requests per batch + #[clap(long, env)] + max_batch_size: Option<usize>, /// The IP address to listen on 
#[clap(default_value = "0.0.0.0", long, env)] @@ -165,6 +167,7 @@ async fn main() -> Result<(), RouterError> { use_mlock: args.use_mlock, flash_attention: args.flash_attention, max_batch_total_tokens: args.max_batch_total_tokens, + max_batch_size: args.max_batch_size, batch_timeout: tokio::time::Duration::from_millis(100), }, tokenizer,