mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Simplify batching logic
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
d3a772a8dd
commit
dbee804129
@ -21,6 +21,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
|
|||||||
use tracing::{debug, info, warn, error, trace};
|
use tracing::{debug, info, warn, error, trace};
|
||||||
use tracing::{instrument};
|
use tracing::{instrument};
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
use std::mem::replace;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub enum LlamacppSplitMode {
|
pub enum LlamacppSplitMode {
|
||||||
@ -466,35 +467,29 @@ impl LlamacppBackend {
|
|||||||
let mut n_tokens = 0;
|
let mut n_tokens = 0;
|
||||||
let mut requests = Vec::with_capacity(conf.max_batch_size);
|
let mut requests = Vec::with_capacity(conf.max_batch_size);
|
||||||
|
|
||||||
|
let flush = |requests: &mut Vec<_>, n_tokens: &mut usize| {
|
||||||
|
if !requests.is_empty() {
|
||||||
|
let _ = sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size)));
|
||||||
|
*n_tokens = 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
loop {
|
loop {
|
||||||
match timeout(conf.batch_timeout, rx.recv()).await {
|
match timeout(conf.batch_timeout, rx.recv()).await {
|
||||||
Ok(None) => break, // closed
|
|
||||||
Ok(Some(request)) => {
|
Ok(Some(request)) => {
|
||||||
if requests.len() + 1 == conf.max_batch_size {
|
|
||||||
requests.push(request);
|
|
||||||
let _ = sync_tx.send(requests);
|
|
||||||
n_tokens = 0;
|
|
||||||
requests = Vec::new();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let n_tokens_to_add = request.input_ids.len();
|
let n_tokens_to_add = request.input_ids.len();
|
||||||
|
|
||||||
if n_tokens + n_tokens_to_add > conf.max_batch_total_tokens as usize {
|
if n_tokens + n_tokens_to_add > conf.max_batch_total_tokens as usize {
|
||||||
let _ = sync_tx.send(requests);
|
flush(&mut requests, &mut n_tokens);
|
||||||
n_tokens = n_tokens_to_add;
|
|
||||||
requests = vec![request];
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
n_tokens += n_tokens_to_add;
|
n_tokens += n_tokens_to_add;
|
||||||
requests.push(request);
|
requests.push(request);
|
||||||
},
|
|
||||||
Err(_) => {
|
if requests.len() == conf.max_batch_size {
|
||||||
if !requests.is_empty() {
|
flush(&mut requests, &mut n_tokens);
|
||||||
let _ = sync_tx.send(requests);
|
|
||||||
n_tokens = 0;
|
|
||||||
requests = Vec::new();
|
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
Ok(None) => break, // closed
|
||||||
|
Err(_) => flush(&mut requests, &mut n_tokens), // timeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user