From bd0cc9905c672b643aecc450a40b9f3be26e18b3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?=
Date: Thu, 30 Jan 2025 13:41:35 +0000
Subject: [PATCH] Get rid of llama_batch_get_one()
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Adrien Gallouët
---
 backends/llamacpp/src/backend.rs | 128 +++++++++++++++++++++++--------
 backends/llamacpp/src/main.rs    |   2 +
 2 files changed, 100 insertions(+), 30 deletions(-)

diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index bb61b4ad..80c04bc3 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -7,7 +7,7 @@ mod bindings {
 }
 use async_trait::async_trait;
 use std::ffi::CString;
-use std::sync::Once;
+use std::sync::{mpsc, Once};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::{ValidGenerateRequest};
 use text_generation_router::{FinishReason, Token};
@@ -15,8 +15,8 @@ use thiserror::Error;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::sync::{watch, oneshot};
-use tokio::task::spawn_blocking;
-use tokio::time::Instant;
+use tokio::task::{spawn, spawn_blocking};
+use tokio::time::{Duration, Instant, timeout};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, info, warn, error, trace};
 use tracing::{instrument};
@@ -24,6 +24,8 @@
 pub struct LlamacppConfig {
     pub model_gguf: String,
     pub n_ctx: u32,
+    pub batch_size: usize,
+    pub batch_timeout: Duration,
     pub n_threads: i32,
     pub use_mmap: bool,
     pub use_mlock: bool,
@@ -85,6 +87,7 @@
 struct Llamacpp {
     model: *mut bindings::llama_model,
     ctx: *mut bindings::llama_context,
     vocab: *const bindings::llama_vocab,
+    batch: bindings::llama_batch,
     n_ctx: u32,
 }
@@ -138,8 +141,39 @@
         if vocab.is_null() {
             return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
         }
-        Ok(Llamacpp{model, ctx, vocab, n_ctx})
+        let batch = unsafe {
+            bindings::llama_batch_init(4096, 0, 5)
+        };
+        // TODO check batch
+        Ok(Llamacpp{model, ctx, vocab, n_ctx, batch})
     }
+
+    fn batch_push(
+        &mut self,
+        token: bindings::llama_token,
+        pos: bindings::llama_pos,
+        seq_ids: &[bindings::llama_seq_id],
+        logits: bool,
+    ) {
+        // TODO check everything..
+        let n = self.batch.n_tokens as usize;
+
+        unsafe {
+            *self.batch.token.add(n) = token;
+            *self.batch.pos.add(n) = pos;
+            *self.batch.n_seq_id.add(n) = seq_ids.len() as i32;
+        }
+        for (i, &seq_id) in seq_ids.iter().enumerate() {
+            unsafe {
+                *(*self.batch.seq_id.add(n)).add(i) = seq_id;
+            }
+        }
+        unsafe {
+            *self.batch.logits.add(n) = logits as i8;
+        }
+        self.batch.n_tokens += 1;
+    }
+
+    // useless ?
     fn warmup(&self) {
         let mut buf: Vec<bindings::llama_token> = Vec::new();
@@ -181,6 +215,7 @@ impl Drop for Llamacpp {
         if !self.model.is_null() {
             unsafe { bindings::llama_model_free(self.model) };
         }
+        unsafe { bindings::llama_batch_free(self.batch) };
     }
 }
@@ -223,12 +258,12 @@ impl LlamacppSampler {
         };
         let mut failed = false;

-        for (k, v) in &[("top_k", top_k),
-                        ("top_p", top_p),
+        for (k, v) in &[( "top_k", top_k ),
+                        ( "top_p", top_p ),
                         ("typical_p", typical_p),
-                        ("temp", temp),
+                        ( "temp", temp ),
                         ("penalties", penalties),
-                        ("dist", dist)] {
+                        ( "dist", dist )] {
             if v.is_null() {
                 error!("Failed to init {k} sampler");
                 failed = true;
@@ -275,9 +310,33 @@ impl LlamacppBackend {
         let (status_tx, status_rx) = watch::channel(false);
         let (ok_tx, ok_rx) = oneshot::channel();
         let (tx, mut rx) = unbounded_channel::<LlamacppRequest>();
+        let (sync_tx, sync_rx) = mpsc::channel();
+
+        spawn(async move {
+            let mut requests = Vec::new();
+
+            loop {
+                match timeout(conf.batch_timeout, rx.recv()).await {
+                    Ok(None) => break, // closed
+                    Ok(Some(request)) => {
+                        requests.push(request);
+                        if requests.len() >= conf.batch_size {
+                            let _ = sync_tx.send(requests);
+                            requests = Vec::new();
+                        }
+                    },
+                    Err(_) => {
+                        if !requests.is_empty() {
+                            let _ = sync_tx.send(requests);
+                            requests = Vec::new();
+                        }
+                    }
+                }
+            }
+        });

         spawn_blocking(move || {
-            let llamacpp = match Llamacpp::new(conf) {
+            let mut llamacpp = match Llamacpp::new(conf) {
                 Ok(v)  => { let _ = ok_tx.send(Ok(())); v },
                 Err(e) => { let _ = ok_tx.send(Err(e)); return; },
             };
@@ -288,18 +347,25 @@
             // health() returns true
             let _ = status_tx.send(true);

-            while let Some(request) = rx.blocking_recv() {
-                debug!("Request: {:?}", request);
-
-                let start_time = Instant::now();
+            while let Ok(requests) = sync_rx.recv() {

                 // TODO: do a real batch
-                let mut batch = unsafe {
-                    bindings::llama_batch_get_one(
-                        request.input_ids.as_ptr() as _,
-                        request.input_ids.len() as _,
-                    )
-                };
+                for (_seq_id, request) in requests.iter().enumerate() {
+
+                    debug!("Request: {:?}", request);
+                    let start_time = Instant::now();
+                    llamacpp.batch.n_tokens = 0;
+
+                    for (pos, &token_id) in request.input_ids.iter().enumerate() {
+                        llamacpp.batch_push(
+                            token_id as bindings::llama_token,
+                            pos as bindings::llama_pos,
+                            &[/* seq_id */ 0 as bindings::llama_seq_id],
+                            true,
+                        );
+                    }
+                    // TODO: close this loop :)
+                    // TODO: move up for perf ?
                 let sampler = match LlamacppSampler::new(&request) {
                     Some(sampler) => sampler,
@@ -310,10 +376,10 @@
                 };
                 let mut text = String::with_capacity(1024);
                 let mut n_tokens: usize = 0;
+                let mut n_new_tokens: usize = 0;

                 loop {
-                    debug!(?batch);
-                    match unsafe { bindings::llama_decode(llamacpp.ctx, batch) } {
+                    match unsafe { bindings::llama_decode(llamacpp.ctx, llamacpp.batch) } {
                         0 => { },
                         1 => {
                             unsafe {
@@ -321,7 +387,7 @@
                                 bindings::llama_kv_cache_clear(llamacpp.ctx);
                             }
                             let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                            continue;
+                            break;
                         },
                         _ => {
                             debug!("decode return <0");
@@ -329,9 +395,11 @@
                             break;
                         },
                     };
-                    let mut next = sampler.sample(&llamacpp);
-                    n_tokens += 1;
-                    debug!(?n_tokens);
+                    let next = sampler.sample(&llamacpp);
+                    n_tokens += llamacpp.batch.n_tokens as usize;
+                    n_new_tokens += llamacpp.batch.n_tokens as usize;
+
+                    debug!("tokens: {n_tokens} new: {n_new_tokens}");

                     let logits = unsafe {
                         *bindings::llama_get_logits_ith(llamacpp.ctx, -1)
@@ -361,7 +429,7 @@
                     let finish: Option<FinishReason> = {
                         if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
                             Some(FinishReason::EndOfSequenceToken)
-                        } else if n_tokens == request.max_new_tokens {
+                        } else if n_new_tokens == request.max_new_tokens {
                             Some(FinishReason::Length)
                         } else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
                             Some(FinishReason::Length) // TODO: check
@@ -375,7 +443,7 @@
                             top_tokens: vec![],
                             generated_text: GeneratedText {
                                 text: text,
-                                generated_tokens: n_tokens as _,
+                                generated_tokens: n_new_tokens as _,
                                 finish_reason: reason,
                                 seed: Some(request.seed as _),
                             },
@@ -388,11 +456,11 @@
                             token: token,
                             top_tokens: vec![],
                         }));
-                    batch = unsafe {
-                        bindings::llama_batch_get_one(&mut next, 1)
-                    };
+                    llamacpp.batch.n_tokens = 0;
+                    llamacpp.batch_push(next, n_tokens as _, &[0], true);
                 }
             }
+            } // TODO remove this
         });
         (Self{tx, status: status_rx}, ok_rx)
     }
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 00d84ceb..800792e5 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -161,6 +161,8 @@ async fn main() -> Result<(), RouterError> {
             use_mmap: args.use_mmap,
             use_mlock: args.use_mlock,
             flash_attention: args.flash_attention,
+            batch_size: 5,
+            batch_timeout: tokio::time::Duration::from_millis(100),
         },
         tokenizer,
     );
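
Note on batch_push(): llama_batch_init(4096, 0, 5) reserves room for 4096 token rows with up to 5 sequence ids each, and batch_push() writes into those raw C arrays with no bounds check (hence the "TODO check everything" comment). A bounds-checked sketch of the same push logic, written against a pure-Rust mock of the llama_batch layout rather than the real bindings (MockBatch and its fields are invented for illustration):

    // Mock of llama.cpp's llama_batch layout, using owned Vecs instead of
    // C arrays. Illustrative only; not the generated bindings.
    struct MockBatch {
        n_tokens_max: usize,
        n_seq_max: usize,
        n_tokens: i32,
        token: Vec<i32>,       // llama_token
        pos: Vec<i32>,         // llama_pos
        n_seq_id: Vec<i32>,
        seq_id: Vec<Vec<i32>>, // llama_seq_id rows, one per token
        logits: Vec<i8>,
    }

    impl MockBatch {
        fn new(n_tokens_max: usize, n_seq_max: usize) -> Self {
            MockBatch {
                n_tokens_max,
                n_seq_max,
                n_tokens: 0,
                token: vec![0; n_tokens_max],
                pos: vec![0; n_tokens_max],
                n_seq_id: vec![0; n_tokens_max],
                seq_id: vec![vec![0; n_seq_max]; n_tokens_max],
                logits: vec![0; n_tokens_max],
            }
        }

        // Same writes as batch_push() above, but refuses to overflow the
        // capacity picked at init time and reports success to the caller.
        fn push(&mut self, token: i32, pos: i32, seq_ids: &[i32], logits: bool) -> bool {
            let n = self.n_tokens as usize;
            if n >= self.n_tokens_max || seq_ids.len() > self.n_seq_max {
                return false;
            }
            self.token[n] = token;
            self.pos[n] = pos;
            self.n_seq_id[n] = seq_ids.len() as i32;
            self.seq_id[n][..seq_ids.len()].copy_from_slice(seq_ids);
            self.logits[n] = logits as i8;
            self.n_tokens += 1;
            true
        }
    }

    fn main() {
        let mut batch = MockBatch::new(4096, 5);
        assert!(batch.push(42, 0, &[0], true));
    }

Returning false (or the row index) from the real batch_push() would let the prompt loop reject a request whose input_ids overflow the 4096 reserved rows instead of writing past the allocation.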
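
Note on the request-coalescing task: it forwards a batch to the blocking inference thread either when batch_size requests are pending or when batch_timeout expires with nothing new arriving; since timeout() is re-armed on every received request, the flush deadline slides under a steady trickle of traffic. A standalone sketch of the same size-or-timeout pattern (channel payload and names are invented; it also flushes leftovers on shutdown, which the Ok(None) => break arm above drops):

    use std::sync::mpsc;
    use std::time::Duration;
    use tokio::sync::mpsc::unbounded_channel;
    use tokio::time::timeout;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = unbounded_channel::<u32>();
        let (sync_tx, sync_rx) = mpsc::channel::<Vec<u32>>();
        let (batch_size, batch_timeout) = (3, Duration::from_millis(100));

        tokio::spawn(async move {
            let mut pending = Vec::new();
            loop {
                match timeout(batch_timeout, rx.recv()).await {
                    // Channel closed: flush leftovers, then stop.
                    Ok(None) => {
                        if !pending.is_empty() {
                            let _ = sync_tx.send(pending);
                        }
                        break;
                    }
                    // New request: flush once the batch is full.
                    Ok(Some(item)) => {
                        pending.push(item);
                        if pending.len() >= batch_size {
                            let _ = sync_tx.send(std::mem::take(&mut pending));
                        }
                    }
                    // Timer fired with no new request: flush what we have.
                    Err(_) => {
                        if !pending.is_empty() {
                            let _ = sync_tx.send(std::mem::take(&mut pending));
                        }
                    }
                }
            }
        });

        // Blocking receiver stands in for the spawn_blocking inference thread.
        let consumer = tokio::task::spawn_blocking(move || {
            while let Ok(batch) = sync_rx.recv() {
                println!("batch: {batch:?}");
            }
        });

        for i in 0..5 {
            tx.send(i).unwrap();
        }
        drop(tx);
        consumer.await.unwrap();
    }

With batch_size 3, sending five requests and closing yields two batches: [0, 1, 2] on the size trigger, then [3, 4] on shutdown.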
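
Note on the two counters: after this change n_tokens counts every position fed through llama_decode (prompt plus generated), which is exactly the position the next sampled token needs in batch_push(next, n_tokens as _, ...), while n_new_tokens counts only sampled tokens and is what must be compared against max_new_tokens; since n_tokens now includes the prompt, comparing it to max_new_tokens (as before the patch) would stop a prompt's length too early. A self-contained sketch of the stopping rule (the enum and function here are illustrative stand-ins, not the router's types):

    #[derive(Debug, PartialEq)]
    enum FinishReason {
        EndOfSequenceToken,
        Length,
    }

    // Mirrors the finish check in the decode loop: stop on an end-of-generation
    // token, on reaching the requested number of *new* tokens, or on a full
    // KV cache.
    fn finish_reason(
        is_eog: bool,
        n_new_tokens: usize,
        max_new_tokens: usize,
        kv_cache_used_cells: i32,
        n_ctx: i32,
    ) -> Option<FinishReason> {
        if is_eog {
            Some(FinishReason::EndOfSequenceToken)
        } else if n_new_tokens >= max_new_tokens {
            Some(FinishReason::Length)
        } else if kv_cache_used_cells >= n_ctx {
            Some(FinishReason::Length)
        } else {
            None
        }
    }

    fn main() {
        // Prompt of 8 tokens, then 4 generated: n_tokens = 12, n_new_tokens = 4.
        // Comparing n_tokens (12) to max_new_tokens (16) would end generation
        // 8 tokens early; comparing n_new_tokens does not.
        assert_eq!(finish_reason(false, 4, 16, 12, 512), None);
        assert_eq!(finish_reason(false, 16, 16, 24, 512), Some(FinishReason::Length));
        assert_eq!(finish_reason(true, 2, 16, 10, 512), Some(FinishReason::EndOfSequenceToken));
    }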