diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 80c04bc3..88726756 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -24,7 +24,7 @@ use tracing::{instrument};
 pub struct LlamacppConfig {
     pub model_gguf: String,
     pub n_ctx: u32,
-    pub batch_size: usize,
+    pub max_batch_total_tokens: u32,
     pub batch_timeout: Duration,
     pub n_threads: i32,
     pub use_mmap: bool,
@@ -142,7 +142,7 @@ impl Llamacpp {
             return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
         }
         let batch = unsafe {
-            bindings::llama_batch_init(4096, 0, 5)
+            bindings::llama_batch_init(conf.max_batch_total_tokens as _, 0, 1)
         };
         // TODO check batch
         Ok(Llamacpp{model, ctx, vocab, n_ctx, batch})
@@ -313,21 +313,26 @@ impl LlamacppBackend {
         let (sync_tx, sync_rx) = mpsc::channel();
 
         spawn(async move {
+            let mut n_tokens = 0;
             let mut requests = Vec::new();
 
             loop {
                 match timeout(conf.batch_timeout, rx.recv()).await {
                     Ok(None) => break, // closed
                     Ok(Some(request)) => {
-                        requests.push(request);
-                        if requests.len() >= conf.batch_size {
+                        if n_tokens + request.input_ids.len() > conf.max_batch_total_tokens as usize {
                             let _ = sync_tx.send(requests);
-                            requests = Vec::new();
+                            n_tokens = request.input_ids.len();
+                            requests = vec![request];
+                        } else {
+                            n_tokens += request.input_ids.len();
+                            requests.push(request);
                         }
                     },
                     Err(_) => {
                         if !requests.is_empty() {
                             let _ = sync_tx.send(requests);
+                            n_tokens = 0;
                             requests = Vec::new();
                         }
                     }
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 800792e5..ea6743a7 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -68,8 +68,11 @@ struct Args {
 // waiting_served_ratio: f32,
 // #[clap(default_value = "4096", long, env)]
 // max_batch_prefill_tokens: u32,
-// #[clap(long, env)]
-// max_batch_total_tokens: Option<u32>,
+
+    /// Maximum tokens within a batch
+    #[clap(default_value = "1024", long, env)]
+    max_batch_total_tokens: u32,
+
 // #[clap(default_value = "20", long, env)]
 // max_waiting_tokens: usize,
 // #[clap(long, env)]
@@ -155,14 +158,14 @@ async fn main() -> Result<(), RouterError> {
 
     let (backend, ok) = LlamacppBackend::new(
         LlamacppConfig {
-            model_gguf: args.model_gguf,
-            n_ctx: args.n_ctx,
-            n_threads: args.n_threads,
-            use_mmap: args.use_mmap,
-            use_mlock: args.use_mlock,
-            flash_attention: args.flash_attention,
-            batch_size: 5,
-            batch_timeout: tokio::time::Duration::from_millis(100),
+            model_gguf: args.model_gguf,
+            n_ctx: args.n_ctx,
+            n_threads: args.n_threads,
+            use_mmap: args.use_mmap,
+            use_mlock: args.use_mlock,
+            flash_attention: args.flash_attention,
+            max_batch_total_tokens: args.max_batch_total_tokens,
+            batch_timeout: tokio::time::Duration::from_millis(100),
         },
         tokenizer,
     );
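
Below is a minimal, self-contained sketch of the token-budget batching policy this change introduces: requests accumulate until adding the next one would push the running token count past max_batch_total_tokens, at which point the pending batch is flushed and a fresh batch is started with the incoming request; a receive timeout also flushes whatever has accumulated. The Request struct, the std::sync::mpsc channels, and the main driver here are illustrative stand-ins only; the real backend uses tokio channels and tokio::time::timeout as in the diff above.

// Sketch only: a simplified Request and std::sync::mpsc stand in for the
// backend's real types and tokio channels.
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

struct Request {
    input_ids: Vec<u32>,
}

fn main() {
    let max_batch_total_tokens: usize = 8;
    let batch_timeout = Duration::from_millis(100);
    let (tx, rx) = mpsc::channel::<Request>();
    let (sync_tx, sync_rx) = mpsc::channel::<Vec<Request>>();

    let batcher = thread::spawn(move || {
        let mut n_tokens = 0usize;
        let mut requests: Vec<Request> = Vec::new();
        loop {
            match rx.recv_timeout(batch_timeout) {
                Ok(request) => {
                    if n_tokens + request.input_ids.len() > max_batch_total_tokens {
                        // Budget exceeded: flush the pending batch and start a
                        // new one containing only the incoming request.
                        let _ = sync_tx.send(std::mem::take(&mut requests));
                        n_tokens = request.input_ids.len();
                        requests.push(request);
                    } else {
                        // Still under budget: keep accumulating.
                        n_tokens += request.input_ids.len();
                        requests.push(request);
                    }
                }
                Err(mpsc::RecvTimeoutError::Timeout) => {
                    // No request arrived within the timeout: flush what we have.
                    if !requests.is_empty() {
                        let _ = sync_tx.send(std::mem::take(&mut requests));
                        n_tokens = 0;
                    }
                }
                Err(mpsc::RecvTimeoutError::Disconnected) => break,
            }
        }
    });

    // Feed a few fake prompts: 3 + 4 tokens fit in one batch of 8; the
    // 5-token prompt forces a flush and opens a second batch, which the
    // timeout then flushes.
    for len in [3usize, 4, 5] {
        tx.send(Request { input_ids: vec![0; len] }).unwrap();
    }
    thread::sleep(Duration::from_millis(300));
    drop(tx);

    while let Ok(batch) = sync_rx.recv() {
        let total: usize = batch.iter().map(|r| r.input_ids.len()).sum();
        println!("batch of {} request(s), {} tokens", batch.len(), total);
    }
    batcher.join().unwrap();
}

With clap's usual derive conventions, the long/env attributes on the new field should expose it as --max-batch-total-tokens on the command line or as the MAX_BATCH_TOTAL_TOKENS environment variable, defaulting to 1024 tokens per batch.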