From dbee80412967b6f52da138acf8b39efb81740234 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Wed, 5 Feb 2025 10:12:39 +0000 Subject: [PATCH] Simplify batching logic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- backends/llamacpp/src/backend.rs | 33 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs index 870798e7..c6f4e925 100644 --- a/backends/llamacpp/src/backend.rs +++ b/backends/llamacpp/src/backend.rs @@ -21,6 +21,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, info, warn, error, trace}; use tracing::{instrument}; use std::str::FromStr; +use std::mem::replace; #[derive(Debug, Clone, Copy)] pub enum LlamacppSplitMode { @@ -466,35 +467,29 @@ impl LlamacppBackend { let mut n_tokens = 0; let mut requests = Vec::with_capacity(conf.max_batch_size); + let flush = |requests: &mut Vec<_>, n_tokens: &mut usize| { + if !requests.is_empty() { + let _ = sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size))); + *n_tokens = 0; + } + }; loop { match timeout(conf.batch_timeout, rx.recv()).await { - Ok(None) => break, // closed Ok(Some(request)) => { - if requests.len() + 1 == conf.max_batch_size { - requests.push(request); - let _ = sync_tx.send(requests); - n_tokens = 0; - requests = Vec::new(); - continue; - } let n_tokens_to_add = request.input_ids.len(); if n_tokens + n_tokens_to_add > conf.max_batch_total_tokens as usize { - let _ = sync_tx.send(requests); - n_tokens = n_tokens_to_add; - requests = vec![request]; - continue; + flush(&mut requests, &mut n_tokens); } n_tokens += n_tokens_to_add; requests.push(request); - }, - Err(_) => { - if !requests.is_empty() { - let _ = sync_tx.send(requests); - n_tokens = 0; - requests = Vec::new(); + + if requests.len() == conf.max_batch_size { + flush(&mut requests, &mut n_tokens); } - } + }, + Ok(None) => break, // closed + Err(_) => flush(&mut requests, &mut n_tokens), // timeout } } });