mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-26 01:40:17 +00:00
Get rid of llama_batch_get_one()
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
95e221eece
commit
bd0cc9905c
backends/llamacpp/src
@ -7,7 +7,7 @@ mod bindings {
|
||||
}
|
||||
use async_trait::async_trait;
|
||||
use std::ffi::CString;
|
||||
use std::sync::Once;
|
||||
use std::sync::{mpsc, Once};
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::{ValidGenerateRequest};
|
||||
use text_generation_router::{FinishReason, Token};
|
||||
@ -15,8 +15,8 @@ use thiserror::Error;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||
use tokio::sync::{watch, oneshot};
|
||||
use tokio::task::spawn_blocking;
|
||||
use tokio::time::Instant;
|
||||
use tokio::task::{spawn, spawn_blocking};
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, info, warn, error, trace};
|
||||
use tracing::{instrument};
|
||||
@ -24,6 +24,8 @@ use tracing::{instrument};
|
||||
pub struct LlamacppConfig {
|
||||
pub model_gguf: String,
|
||||
pub n_ctx: u32,
|
||||
pub batch_size: usize,
|
||||
pub batch_timeout: Duration,
|
||||
pub n_threads: i32,
|
||||
pub use_mmap: bool,
|
||||
pub use_mlock: bool,
|
||||
@ -85,6 +87,7 @@ struct Llamacpp {
|
||||
model: *mut bindings::llama_model,
|
||||
ctx: *mut bindings::llama_context,
|
||||
vocab: *const bindings::llama_vocab,
|
||||
batch: bindings::llama_batch,
|
||||
n_ctx: u32,
|
||||
}
|
||||
|
||||
@ -138,8 +141,39 @@ impl Llamacpp {
|
||||
if vocab.is_null() {
|
||||
return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
|
||||
}
|
||||
Ok(Llamacpp{model, ctx, vocab, n_ctx})
|
||||
let batch = unsafe {
|
||||
bindings::llama_batch_init(4096, 0, 5)
|
||||
};
|
||||
// TODO check batch
|
||||
Ok(Llamacpp{model, ctx, vocab, n_ctx, batch})
|
||||
}
|
||||
|
||||
fn batch_push(
|
||||
&mut self,
|
||||
token: bindings::llama_token,
|
||||
pos: bindings::llama_pos,
|
||||
seq_ids: &[bindings::llama_seq_id],
|
||||
logits: bool,
|
||||
) {
|
||||
// TODO check evertyhing..
|
||||
let n = self.batch.n_tokens as usize;
|
||||
|
||||
unsafe {
|
||||
*self.batch.token.add(n) = token;
|
||||
*self.batch.pos.add(n) = pos;
|
||||
*self.batch.n_seq_id.add(n) = seq_ids.len() as i32;
|
||||
}
|
||||
for (i, &seq_id) in seq_ids.iter().enumerate() {
|
||||
unsafe {
|
||||
*(*self.batch.seq_id.add(n)).add(i) = seq_id;
|
||||
}
|
||||
}
|
||||
unsafe {
|
||||
*self.batch.logits.add(n) = logits as i8;
|
||||
}
|
||||
self.batch.n_tokens += 1;
|
||||
}
|
||||
|
||||
// useless ?
|
||||
fn warmup(&self) {
|
||||
let mut buf: Vec<bindings::llama_token> = Vec::new();
|
||||
@ -181,6 +215,7 @@ impl Drop for Llamacpp {
|
||||
if !self.model.is_null() {
|
||||
unsafe { bindings::llama_model_free(self.model) };
|
||||
}
|
||||
unsafe { bindings::llama_batch_free(self.batch) };
|
||||
}
|
||||
}
|
||||
|
||||
@ -223,12 +258,12 @@ impl LlamacppSampler {
|
||||
};
|
||||
let mut failed = false;
|
||||
|
||||
for (k, v) in &[("top_k", top_k),
|
||||
("top_p", top_p),
|
||||
for (k, v) in &[( "top_k", top_k ),
|
||||
( "top_p", top_p ),
|
||||
("typical_p", typical_p),
|
||||
("temp", temp),
|
||||
( "temp", temp ),
|
||||
("penalties", penalties),
|
||||
("dist", dist)] {
|
||||
( "dist", dist )] {
|
||||
if v.is_null() {
|
||||
error!("Failed to init {k} sampler");
|
||||
failed = true;
|
||||
@ -275,9 +310,33 @@ impl LlamacppBackend {
|
||||
let (status_tx, status_rx) = watch::channel(false);
|
||||
let (ok_tx, ok_rx) = oneshot::channel();
|
||||
let (tx, mut rx) = unbounded_channel::<LlamacppRequest>();
|
||||
let (sync_tx, sync_rx) = mpsc::channel();
|
||||
|
||||
spawn(async move {
|
||||
let mut requests = Vec::new();
|
||||
|
||||
loop {
|
||||
match timeout(conf.batch_timeout, rx.recv()).await {
|
||||
Ok(None) => break, // closed
|
||||
Ok(Some(request)) => {
|
||||
requests.push(request);
|
||||
if requests.len() >= conf.batch_size {
|
||||
let _ = sync_tx.send(requests);
|
||||
requests = Vec::new();
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
if !requests.is_empty() {
|
||||
let _ = sync_tx.send(requests);
|
||||
requests = Vec::new();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
spawn_blocking(move || {
|
||||
let llamacpp = match Llamacpp::new(conf) {
|
||||
let mut llamacpp = match Llamacpp::new(conf) {
|
||||
Ok(v) => { let _ = ok_tx.send(Ok(())); v },
|
||||
Err(e) => { let _ = ok_tx.send(Err(e)); return; },
|
||||
};
|
||||
@ -288,18 +347,25 @@ impl LlamacppBackend {
|
||||
// health() returns true
|
||||
let _ = status_tx.send(true);
|
||||
|
||||
while let Some(request) = rx.blocking_recv() {
|
||||
debug!("Request: {:?}", request);
|
||||
|
||||
let start_time = Instant::now();
|
||||
while let Ok(requests) = sync_rx.recv() {
|
||||
|
||||
// TODO: do a real batch
|
||||
let mut batch = unsafe {
|
||||
bindings::llama_batch_get_one(
|
||||
request.input_ids.as_ptr() as _,
|
||||
request.input_ids.len() as _,
|
||||
)
|
||||
};
|
||||
for (_seq_id, request) in requests.iter().enumerate() {
|
||||
|
||||
debug!("Request: {:?}", request);
|
||||
let start_time = Instant::now();
|
||||
llamacpp.batch.n_tokens = 0;
|
||||
|
||||
for (pos, &token_id) in request.input_ids.iter().enumerate() {
|
||||
llamacpp.batch_push(
|
||||
token_id as bindings::llama_token,
|
||||
pos as bindings::llama_pos,
|
||||
&[/* seq_id */ 0 as bindings::llama_seq_id],
|
||||
true,
|
||||
);
|
||||
}
|
||||
// TODO: close this loop :)
|
||||
|
||||
// TODO: move up for perf ?
|
||||
let sampler = match LlamacppSampler::new(&request) {
|
||||
Some(sampler) => sampler,
|
||||
@ -310,10 +376,10 @@ impl LlamacppBackend {
|
||||
};
|
||||
let mut text = String::with_capacity(1024);
|
||||
let mut n_tokens: usize = 0;
|
||||
let mut n_new_tokens: usize = 0;
|
||||
|
||||
loop {
|
||||
debug!(?batch);
|
||||
match unsafe { bindings::llama_decode(llamacpp.ctx, batch) } {
|
||||
match unsafe { bindings::llama_decode(llamacpp.ctx, llamacpp.batch) } {
|
||||
0 => { },
|
||||
1 => {
|
||||
unsafe {
|
||||
@ -321,7 +387,7 @@ impl LlamacppBackend {
|
||||
bindings::llama_kv_cache_clear(llamacpp.ctx);
|
||||
}
|
||||
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
|
||||
continue;
|
||||
break;
|
||||
},
|
||||
_ => {
|
||||
debug!("decode return <0");
|
||||
@ -329,9 +395,11 @@ impl LlamacppBackend {
|
||||
break;
|
||||
},
|
||||
};
|
||||
let mut next = sampler.sample(&llamacpp);
|
||||
n_tokens += 1;
|
||||
debug!(?n_tokens);
|
||||
let next = sampler.sample(&llamacpp);
|
||||
n_tokens += llamacpp.batch.n_tokens as usize;
|
||||
n_new_tokens += llamacpp.batch.n_tokens as usize;
|
||||
|
||||
debug!("tokens: {n_tokens} new: {n_new_tokens}");
|
||||
|
||||
let logits = unsafe {
|
||||
*bindings::llama_get_logits_ith(llamacpp.ctx, -1)
|
||||
@ -361,7 +429,7 @@ impl LlamacppBackend {
|
||||
let finish: Option<FinishReason> = {
|
||||
if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
|
||||
Some(FinishReason::EndOfSequenceToken)
|
||||
} else if n_tokens == request.max_new_tokens {
|
||||
} else if n_new_tokens == request.max_new_tokens {
|
||||
Some(FinishReason::Length)
|
||||
} else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
|
||||
Some(FinishReason::Length) // TODO: check
|
||||
@ -375,7 +443,7 @@ impl LlamacppBackend {
|
||||
top_tokens: vec![],
|
||||
generated_text: GeneratedText {
|
||||
text: text,
|
||||
generated_tokens: n_tokens as _,
|
||||
generated_tokens: n_new_tokens as _,
|
||||
finish_reason: reason,
|
||||
seed: Some(request.seed as _),
|
||||
},
|
||||
@ -388,11 +456,11 @@ impl LlamacppBackend {
|
||||
token: token,
|
||||
top_tokens: vec![],
|
||||
}));
|
||||
batch = unsafe {
|
||||
bindings::llama_batch_get_one(&mut next, 1)
|
||||
};
|
||||
llamacpp.batch.n_tokens = 0;
|
||||
llamacpp.batch_push(next, n_tokens as _, &[0], true);
|
||||
}
|
||||
}
|
||||
} // TODO remove this
|
||||
});
|
||||
(Self{tx, status: status_rx}, ok_rx)
|
||||
}
|
||||
|
@ -161,6 +161,8 @@ async fn main() -> Result<(), RouterError> {
|
||||
use_mmap: args.use_mmap,
|
||||
use_mlock: args.use_mlock,
|
||||
flash_attention: args.flash_attention,
|
||||
batch_size: 5,
|
||||
batch_timeout: tokio::time::Duration::from_millis(100),
|
||||
},
|
||||
tokenizer,
|
||||
);
|
||||
|
Loading…
Reference in New Issue
Block a user