Get rid of llama_batch_get_one()

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Adrien Gallouët 2025-01-30 13:41:35 +00:00
parent 95e221eece
commit bd0cc9905c
2 changed files with 100 additions and 30 deletions
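The change drops the temporary batch returned by llama_batch_get_one() in favour of a llama_batch allocated once with llama_batch_init(), filled slot by slot (token, position, sequence ids, logits flag) and freed when the backend is dropped, paving the way for several queued requests to share a single decode batch. A minimal before/after sketch, reusing the same bindings FFI module as the code below (illustrative, not copied verbatim from the diff):

// Before: a borrowed, single-sequence batch over the request's prompt buffer,
// rebuilt for every call to llama_decode().
let batch = unsafe {
    bindings::llama_batch_get_one(
        request.input_ids.as_ptr() as _,
        request.input_ids.len() as _,
    )
};

// After: one owned batch, allocated up front and refilled by hand.
let batch = unsafe { bindings::llama_batch_init(4096, 0, 5) };
// ... write token/pos/seq_id/logits at index batch.n_tokens, bump n_tokens,
// ... pass the batch to llama_decode(), repeat ...
unsafe { bindings::llama_batch_free(batch) }; // when the backend shuts down
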
backends/llamacpp/src


@@ -7,7 +7,7 @@ mod bindings {
}
use async_trait::async_trait;
use std::ffi::CString;
use std::sync::Once;
use std::sync::{mpsc, Once};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::{ValidGenerateRequest};
use text_generation_router::{FinishReason, Token};
@@ -15,8 +15,8 @@ use thiserror::Error;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::{watch, oneshot};
use tokio::task::spawn_blocking;
use tokio::time::Instant;
use tokio::task::{spawn, spawn_blocking};
use tokio::time::{Duration, Instant, timeout};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, info, warn, error, trace};
use tracing::{instrument};
@@ -24,6 +24,8 @@ use tracing::{instrument};
pub struct LlamacppConfig {
pub model_gguf: String,
pub n_ctx: u32,
pub batch_size: usize,
pub batch_timeout: Duration,
pub n_threads: i32,
pub use_mmap: bool,
pub use_mlock: bool,
@@ -85,6 +87,7 @@ struct Llamacpp {
model: *mut bindings::llama_model,
ctx: *mut bindings::llama_context,
vocab: *const bindings::llama_vocab,
batch: bindings::llama_batch,
n_ctx: u32,
}
@@ -138,8 +141,39 @@ impl Llamacpp {
if vocab.is_null() {
return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
}
Ok(Llamacpp{model, ctx, vocab, n_ctx})
let batch = unsafe {
bindings::llama_batch_init(4096, 0, 5)
};
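// llama_batch_init(n_tokens, embd, n_seq_max): reserves 4096 token slots,
// embd == 0 means token ids (not embeddings) are stored, and each token can
// carry up to 5 sequence ids.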
// TODO check batch
Ok(Llamacpp{model, ctx, vocab, n_ctx, batch})
}
fn batch_push(
&mut self,
token: bindings::llama_token,
pos: bindings::llama_pos,
seq_ids: &[bindings::llama_seq_id],
logits: bool,
) {
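// Appends one token at index n_tokens of the preallocated batch: fills the
// parallel token/pos/n_seq_id/seq_id/logits arrays, then bumps n_tokens.
// Nothing checks the 4096-slot capacity reserved above; one option
// (hypothetical, not done here) is to keep that capacity in the struct and
// return false once n_tokens reaches it.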
// TODO check everything..
let n = self.batch.n_tokens as usize;
unsafe {
*self.batch.token.add(n) = token;
*self.batch.pos.add(n) = pos;
*self.batch.n_seq_id.add(n) = seq_ids.len() as i32;
}
for (i, &seq_id) in seq_ids.iter().enumerate() {
unsafe {
*(*self.batch.seq_id.add(n)).add(i) = seq_id;
}
}
unsafe {
*self.batch.logits.add(n) = logits as i8;
}
self.batch.n_tokens += 1;
}
// useless ?
fn warmup(&self) {
let mut buf: Vec<bindings::llama_token> = Vec::new();
@@ -181,6 +215,7 @@ impl Drop for Llamacpp {
if !self.model.is_null() {
unsafe { bindings::llama_model_free(self.model) };
}
unsafe { bindings::llama_batch_free(self.batch) };
}
}
@@ -223,12 +258,12 @@ impl LlamacppSampler {
};
let mut failed = false;
for (k, v) in &[("top_k", top_k),
("top_p", top_p),
for (k, v) in &[( "top_k", top_k ),
( "top_p", top_p ),
("typical_p", typical_p),
("temp", temp),
( "temp", temp ),
("penalties", penalties),
("dist", dist)] {
( "dist", dist )] {
if v.is_null() {
error!("Failed to init {k} sampler");
failed = true;
@@ -275,9 +310,33 @@ impl LlamacppBackend {
let (status_tx, status_rx) = watch::channel(false);
let (ok_tx, ok_rx) = oneshot::channel();
let (tx, mut rx) = unbounded_channel::<LlamacppRequest>();
let (sync_tx, sync_rx) = mpsc::channel();
spawn(async move {
let mut requests = Vec::new();
loop {
match timeout(conf.batch_timeout, rx.recv()).await {
Ok(None) => break, // closed
Ok(Some(request)) => {
requests.push(request);
if requests.len() >= conf.batch_size {
let _ = sync_tx.send(requests);
requests = Vec::new();
}
},
Err(_) => {
if !requests.is_empty() {
let _ = sync_tx.send(requests);
requests = Vec::new();
}
}
}
}
});
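// The task above coalesces incoming requests: it buffers them until either
// batch_size requests are pending or batch_timeout elapses with at least one
// request waiting, then hands the whole group to the blocking inference
// thread through the std mpsc channel (sync_tx / sync_rx).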
spawn_blocking(move || {
let llamacpp = match Llamacpp::new(conf) {
let mut llamacpp = match Llamacpp::new(conf) {
Ok(v) => { let _ = ok_tx.send(Ok(())); v },
Err(e) => { let _ = ok_tx.send(Err(e)); return; },
};
@@ -288,18 +347,25 @@ impl LlamacppBackend {
// health() returns true
let _ = status_tx.send(true);
while let Some(request) = rx.blocking_recv() {
debug!("Request: {:?}", request);
let start_time = Instant::now();
while let Ok(requests) = sync_rx.recv() {
// TODO: do a real batch
let mut batch = unsafe {
bindings::llama_batch_get_one(
request.input_ids.as_ptr() as _,
request.input_ids.len() as _,
)
};
for (_seq_id, request) in requests.iter().enumerate() {
debug!("Request: {:?}", request);
let start_time = Instant::now();
llamacpp.batch.n_tokens = 0;
for (pos, &token_id) in request.input_ids.iter().enumerate() {
llamacpp.batch_push(
token_id as bindings::llama_token,
pos as bindings::llama_pos,
&[/* seq_id */ 0 as bindings::llama_seq_id],
true,
);
}
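// Logits are requested for every prompt token here, although only the last
// position is read back (llama_get_logits_ith(ctx, -1) below); passing true
// for the final token only should be enough.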
// TODO: close this loop :)
// TODO: move up for perf ?
let sampler = match LlamacppSampler::new(&request) {
Some(sampler) => sampler,
@@ -310,10 +376,10 @@ impl LlamacppBackend {
};
let mut text = String::with_capacity(1024);
let mut n_tokens: usize = 0;
let mut n_new_tokens: usize = 0;
loop {
debug!(?batch);
match unsafe { bindings::llama_decode(llamacpp.ctx, batch) } {
match unsafe { bindings::llama_decode(llamacpp.ctx, llamacpp.batch) } {
0 => { },
1 => {
unsafe {
@@ -321,7 +387,7 @@ impl LlamacppBackend {
bindings::llama_kv_cache_clear(llamacpp.ctx);
}
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
continue;
break;
},
_ => {
debug!("decode return <0");
@@ -329,9 +395,11 @@ impl LlamacppBackend {
break;
},
};
let mut next = sampler.sample(&llamacpp);
n_tokens += 1;
debug!(?n_tokens);
let next = sampler.sample(&llamacpp);
n_tokens += llamacpp.batch.n_tokens as usize;
n_new_tokens += llamacpp.batch.n_tokens as usize;
debug!("tokens: {n_tokens} new: {n_new_tokens}");
let logits = unsafe {
*bindings::llama_get_logits_ith(llamacpp.ctx, -1)
@@ -361,7 +429,7 @@ impl LlamacppBackend {
let finish: Option<FinishReason> = {
if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
Some(FinishReason::EndOfSequenceToken)
} else if n_tokens == request.max_new_tokens {
} else if n_new_tokens == request.max_new_tokens {
Some(FinishReason::Length)
} else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
Some(FinishReason::Length) // TODO: check
@@ -375,7 +443,7 @@ impl LlamacppBackend {
top_tokens: vec![],
generated_text: GeneratedText {
text: text,
generated_tokens: n_tokens as _,
generated_tokens: n_new_tokens as _,
finish_reason: reason,
seed: Some(request.seed as _),
},
@@ -388,11 +456,11 @@ impl LlamacppBackend {
token: token,
top_tokens: vec![],
}));
batch = unsafe {
bindings::llama_batch_get_one(&mut next, 1)
};
llamacpp.batch.n_tokens = 0;
llamacpp.batch_push(next, n_tokens as _, &[0], true);
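// Only the newly sampled token is fed back: the KV cache already holds the
// prompt and earlier completions, so the next llama_decode() sees a
// one-token batch at position n_tokens.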
}
}
} // TODO remove this
});
(Self{tx, status: status_rx}, ok_rx)
}
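
After this change the request path is: the router pushes each validated request onto the unbounded tokio channel, the new async task groups whatever arrives (up to batch_size requests, or whatever is pending once batch_timeout elapses), and the blocking llama.cpp thread receives each group over a std::sync::mpsc channel. A self-contained sketch of that coalescing pattern, with a hypothetical u32 item standing in for LlamacppRequest and the same hard-coded sizes as the router config below:

use std::sync::mpsc;
use tokio::sync::mpsc::unbounded_channel;
use tokio::task::{spawn, spawn_blocking};
use tokio::time::{sleep, timeout, Duration};

#[tokio::main]
async fn main() {
    let (tx, mut rx) = unbounded_channel::<u32>();
    let (sync_tx, sync_rx) = mpsc::channel::<Vec<u32>>();
    let batch_size: usize = 5;
    let batch_timeout = Duration::from_millis(100);

    // Async side: group items until batch_size is reached or batch_timeout
    // elapses with at least one item pending, then ship the group.
    spawn(async move {
        let mut pending = Vec::new();
        loop {
            match timeout(batch_timeout, rx.recv()).await {
                Ok(None) => break, // all senders dropped
                Ok(Some(item)) => {
                    pending.push(item);
                    if pending.len() >= batch_size {
                        let _ = sync_tx.send(std::mem::take(&mut pending));
                    }
                }
                Err(_) => {
                    // timeout: flush whatever is waiting
                    if !pending.is_empty() {
                        let _ = sync_tx.send(std::mem::take(&mut pending));
                    }
                }
            }
        }
    });

    // Blocking side, mirroring the spawn_blocking inference thread.
    let worker = spawn_blocking(move || {
        while let Ok(group) = sync_rx.recv() {
            println!("got {} items: {group:?}", group.len());
        }
    });

    for i in 0..7u32 {
        let _ = tx.send(i);
    }
    sleep(Duration::from_millis(200)).await; // let the timeout flush fire
    drop(tx);
    let _ = worker.await;
}
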


@@ -161,6 +161,8 @@ async fn main() -> Result<(), RouterError> {
use_mmap: args.use_mmap,
use_mlock: args.use_mlock,
flash_attention: args.flash_attention,
batch_size: 5,
batch_timeout: tokio::time::Duration::from_millis(100),
},
tokenizer,
);
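
The batch size and timeout are hard-coded here (5 requests, 100 ms). Assuming the router's arguments are declared with clap's derive API like the existing flags, they could later be exposed on the command line; a hypothetical sketch, not part of this commit:

// Hypothetical flags, names made up for illustration:
#[derive(clap::Parser)]
struct Args {
    /// Maximum number of requests decoded together in one llama_batch.
    #[clap(long, default_value_t = 5)]
    max_batch_size: usize,
    /// Time to wait for more requests before flushing a partial batch, in ms.
    #[clap(long, default_value_t = 100)]
    batch_timeout_ms: u64,
    // ... existing flags (model_gguf, n_ctx, use_mmap, use_mlock, ...) ...
}

// ... and in the backend config:
//     batch_size: args.max_batch_size,
//     batch_timeout: Duration::from_millis(args.batch_timeout_ms),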