Use max_batch_total_tokens

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Adrien Gallouët 2025-01-30 15:12:55 +00:00
parent bd0cc9905c
commit 3eb4823f3e
2 changed files with 22 additions and 15 deletions


@@ -24,7 +24,7 @@ use tracing::{instrument};
 pub struct LlamacppConfig {
     pub model_gguf: String,
     pub n_ctx: u32,
-    pub batch_size: usize,
+    pub max_batch_total_tokens: u32,
     pub batch_timeout: Duration,
     pub n_threads: i32,
     pub use_mmap: bool,
@@ -142,7 +142,7 @@ impl Llamacpp {
             return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
         }
         let batch = unsafe {
-            bindings::llama_batch_init(4096, 0, 5)
+            bindings::llama_batch_init(conf.max_batch_total_tokens as _, 0, 1)
         };
         // TODO check batch
         Ok(Llamacpp{model, ctx, vocab, n_ctx, batch})
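
For context on the new arguments: upstream llama.cpp's llama_batch_init(n_tokens, embd, n_seq_max) pre-allocates a batch able to hold up to n_tokens tokens, where embd = 0 means plain token ids rather than embeddings and n_seq_max bounds how many sequence ids can be attached to each token. Sizing the buffer from max_batch_total_tokens keeps the allocation consistent with the scheduler's token budget instead of the hard-coded 4096, and lowering n_seq_max from 5 to 1 assumes one sequence id per token.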
@@ -313,21 +313,25 @@ impl LlamacppBackend {
         let (sync_tx, sync_rx) = mpsc::channel();
 
         spawn(async move {
+            let mut n_tokens = 0;
             let mut requests = Vec::new();
 
             loop {
                 match timeout(conf.batch_timeout, rx.recv()).await {
                     Ok(None) => break, // closed
                     Ok(Some(request)) => {
-                        requests.push(request);
-                        if requests.len() >= conf.batch_size {
+                        if n_tokens + request.input_ids.len() > conf.max_batch_total_tokens as usize {
                             let _ = sync_tx.send(requests);
-                            requests = Vec::new();
+                            n_tokens = request.input_ids.len();
+                            requests = vec![request];
+                        } else {
+                            requests.push(request);
                         }
                     },
                     Err(_) => {
                         if !requests.is_empty() {
                             let _ = sync_tx.send(requests);
+                            n_tokens = 0;
                             requests = Vec::new();
                         }
                     }
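
The substantive change here is the flush rule: instead of dispatching once a batch holds batch_size requests, the loop tracks a running prompt-token count and hands the batch to the sync channel as soon as the next request would push that count past max_batch_total_tokens; the timeout branch still flushes whatever is pending. Below is a minimal, self-contained sketch of that rule, separated from the async plumbing. The names IncomingRequest and TokenBudgetBatcher are illustrative, not the backend's actual types, and the sketch also advances the running count when a request fits in the current batch.

    /// Illustrative request type: only the prompt length matters for batching.
    struct IncomingRequest {
        input_ids: Vec<u32>,
    }

    /// Accumulates requests until the next one would exceed the token budget.
    struct TokenBudgetBatcher {
        max_batch_total_tokens: usize,
        n_tokens: usize,
        requests: Vec<IncomingRequest>,
    }

    impl TokenBudgetBatcher {
        fn new(max_batch_total_tokens: usize) -> Self {
            Self { max_batch_total_tokens, n_tokens: 0, requests: Vec::new() }
        }

        /// Returns a full batch to dispatch when `request` does not fit,
        /// and starts the next batch seeded with `request`.
        fn push(&mut self, request: IncomingRequest) -> Option<Vec<IncomingRequest>> {
            let len = request.input_ids.len();
            if self.n_tokens + len > self.max_batch_total_tokens && !self.requests.is_empty() {
                let full = std::mem::replace(&mut self.requests, vec![request]);
                self.n_tokens = len;
                Some(full)
            } else {
                self.n_tokens += len;
                self.requests.push(request);
                None
            }
        }

        /// Timeout path: dispatch whatever is pending, even a partial batch.
        fn flush(&mut self) -> Option<Vec<IncomingRequest>> {
            if self.requests.is_empty() {
                None
            } else {
                self.n_tokens = 0;
                Some(std::mem::take(&mut self.requests))
            }
        }
    }

    fn main() {
        let mut batcher = TokenBudgetBatcher::new(8);
        for prompt_len in [3usize, 4, 5, 2] {
            let request = IncomingRequest { input_ids: vec![0; prompt_len] };
            if let Some(batch) = batcher.push(request) {
                println!("dispatch batch of {} requests", batch.len());
            }
        }
        if let Some(batch) = batcher.flush() {
            println!("dispatch trailing batch of {} requests", batch.len());
        }
    }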


@@ -68,8 +68,11 @@ struct Args {
     // waiting_served_ratio: f32,
     // #[clap(default_value = "4096", long, env)]
     // max_batch_prefill_tokens: u32,
-    // #[clap(long, env)]
-    // max_batch_total_tokens: Option<u32>,
+    /// Maximum tokens within a batch
+    #[clap(default_value = "1024", long, env)]
+    max_batch_total_tokens: u32,
     // #[clap(default_value = "20", long, env)]
     // max_waiting_tokens: usize,
     // #[clap(long, env)]
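
Assuming clap's usual derive defaults, the newly uncommented field is exposed as the --max-batch-total-tokens command-line flag and the MAX_BATCH_TOTAL_TOKENS environment variable, with a default budget of 1024 tokens per batch.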
@@ -155,14 +158,14 @@ async fn main() -> Result<(), RouterError> {
     let (backend, ok) = LlamacppBackend::new(
         LlamacppConfig {
             model_gguf: args.model_gguf,
             n_ctx: args.n_ctx,
             n_threads: args.n_threads,
             use_mmap: args.use_mmap,
             use_mlock: args.use_mlock,
             flash_attention: args.flash_attention,
-            batch_size: 5,
+            max_batch_total_tokens: args.max_batch_total_tokens,
             batch_timeout: tokio::time::Duration::from_millis(100),
         },
         tokenizer,
     );