Simplify batching logic

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2025-02-05 10:12:39 +00:00
parent d3a772a8dd
commit dbee804129
No known key found for this signature in database

View File

@ -21,6 +21,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, info, warn, error, trace};
use tracing::{instrument};
use std::str::FromStr;
use std::mem::replace;
#[derive(Debug, Clone, Copy)]
pub enum LlamacppSplitMode {
@ -466,35 +467,29 @@ impl LlamacppBackend {
let mut n_tokens = 0;
let mut requests = Vec::with_capacity(conf.max_batch_size);
// Hands the pending batch over `sync_tx` and starts a new, empty one.
//
// Does nothing when `requests` is empty. Otherwise the current vector is
// swapped out for a fresh one pre-sized to `conf.max_batch_size`, the old
// vector is sent, and the running token count is reset to zero.
let flush = |requests: &mut Vec<_>, n_tokens: &mut usize| {
    if requests.is_empty() {
        return;
    }
    // Take ownership of the accumulated batch while leaving a usable vector behind.
    let batch = replace(requests, Vec::with_capacity(conf.max_batch_size));
    // The send result is deliberately discarded.
    let _ = sync_tx.send(batch);
    *n_tokens = 0;
};
loop {
match timeout(conf.batch_timeout, rx.recv()).await {
Ok(None) => break, // closed
Ok(Some(request)) => {
if requests.len() + 1 == conf.max_batch_size {
requests.push(request);
let _ = sync_tx.send(requests);
n_tokens = 0;
requests = Vec::new();
continue;
}
let n_tokens_to_add = request.input_ids.len();
if n_tokens + n_tokens_to_add > conf.max_batch_total_tokens as usize {
let _ = sync_tx.send(requests);
n_tokens = n_tokens_to_add;
requests = vec![request];
continue;
flush(&mut requests, &mut n_tokens);
}
n_tokens += n_tokens_to_add;
requests.push(request);
},
Err(_) => {
if !requests.is_empty() {
let _ = sync_tx.send(requests);
n_tokens = 0;
requests = Vec::new();
if requests.len() == conf.max_batch_size {
flush(&mut requests, &mut n_tokens);
}
}
},
Ok(None) => break, // closed
Err(_) => flush(&mut requests, &mut n_tokens), // timeout
}
}
});