Add a stupid batch mechanism

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2025-01-31 12:44:09 +00:00
parent e07835c5b5
commit d6ded897a8
No known key found for this signature in database
2 changed files with 119 additions and 106 deletions

View File

@ -202,7 +202,7 @@ impl Llamacpp {
Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch}) Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
} }
fn batch_clear_logits(&mut self) { fn _batch_clear_logits(&mut self) {
for n in 0..self.batch.n_tokens as usize{ for n in 0..self.batch.n_tokens as usize{
unsafe { unsafe {
*self.batch.logits.add(n) = 0 as i8; *self.batch.logits.add(n) = 0 as i8;
@ -214,24 +214,15 @@ impl Llamacpp {
&mut self, &mut self,
token: bindings::llama_token, token: bindings::llama_token,
pos: bindings::llama_pos, pos: bindings::llama_pos,
seq_ids: &[bindings::llama_seq_id], seq_id: bindings::llama_seq_id,
logits: bool, logits: bool,
) { ) {
debug!("push {token} {pos} {logits}");
// TODO check everything..
let n = self.batch.n_tokens as usize; let n = self.batch.n_tokens as usize;
unsafe { unsafe {
*self.batch.token.add(n) = token; *self.batch.token.add(n) = token;
*self.batch.pos.add(n) = pos; *self.batch.pos.add(n) = pos;
*self.batch.n_seq_id.add(n) = seq_ids.len() as i32; *self.batch.n_seq_id.add(n) = 1;
} *(*self.batch.seq_id.add(n)).add(0) = seq_id;
for (i, &seq_id) in seq_ids.iter().enumerate() {
unsafe {
*(*self.batch.seq_id.add(n)).add(i) = seq_id;
}
}
unsafe {
*self.batch.logits.add(n) = logits as i8; *self.batch.logits.add(n) = logits as i8;
} }
self.batch.n_tokens += 1; self.batch.n_tokens += 1;
@ -375,6 +366,17 @@ impl Drop for LlamacppSampler {
} }
} }
/// Per-request generation state kept while several requests share one llama batch.
///
/// One value is pushed per request after its prompt tokens have been added to the
/// batch; the decode loop then updates it after every sampled token.
struct LlamacppSeq {
/// Index of the originating request in `requests`; the same value is passed to
/// `batch_push` as the sequence id.
id: usize,
/// Index inside the current batch that is handed to the sampler for this sequence.
/// Initially the index of the last prompt token pushed for the request.
/// NOTE(review): the batch-regeneration loop resets this to 0 for every running
/// sequence — confirm that is intended when more than one sequence is still running.
batch_pos: usize,
/// Most recently sampled token; initialised to -1 before anything has been sampled.
token: bindings::llama_token,
/// Position used for this sequence's next `batch_push`; starts at the prompt length
/// and is incremented after each push.
pos: bindings::llama_pos,
/// Sampler built from the request via `LlamacppSampler::new`.
sampler: LlamacppSampler,
/// Accumulated decoded pieces, excluding pieces the vocab reports as special tokens.
text: String,
/// Number of tokens sampled so far; compared against the request's `max_new_tokens`.
n_new_tokens: usize,
/// Set to false on a decode/tokenizer failure or once a finish reason is reached;
/// only sequences that are still running are pushed into the next batch.
running: bool,
}
static INIT: Once = Once::new(); static INIT: Once = Once::new();
impl LlamacppBackend { impl LlamacppBackend {
@ -397,7 +399,7 @@ impl LlamacppBackend {
spawn(async move { spawn(async move {
let mut n_tokens = 0; let mut n_tokens = 0;
let mut requests = Vec::new(); let mut requests = Vec::with_capacity(conf.max_batch_size);
loop { loop {
match timeout(conf.batch_timeout, rx.recv()).await { match timeout(conf.batch_timeout, rx.recv()).await {
@ -442,27 +444,12 @@ impl LlamacppBackend {
let _ = status_tx.send(true); let _ = status_tx.send(true);
while let Ok(requests) = sync_rx.recv() { while let Ok(requests) = sync_rx.recv() {
// TODO: do a real batch
for (_seq_id, request) in requests.iter().enumerate() {
debug!("Request: {:?}", request);
let start_time = Instant::now(); let start_time = Instant::now();
let mut seqs: Vec<LlamacppSeq> = Vec::with_capacity(requests.len());
llamacpp.batch.n_tokens = 0; llamacpp.batch.n_tokens = 0;
for (pos, &token_id) in request.input_ids.iter().enumerate() { for (seq_id, request) in requests.iter().enumerate() {
llamacpp.batch_push( debug!("Request: {:?}", request);
token_id as bindings::llama_token,
pos as bindings::llama_pos,
&[/* seq_id */ 0 as bindings::llama_seq_id],
true,
);
}
let mut pos = request.input_ids.len();
// TODO: close this loop :)
// TODO: move up for perf ?
let sampler = match LlamacppSampler::new(&request) { let sampler = match LlamacppSampler::new(&request) {
Some(sampler) => sampler, Some(sampler) => sampler,
_ => { _ => {
@ -470,48 +457,67 @@ impl LlamacppBackend {
continue; continue;
}, },
}; };
let mut text = String::with_capacity(1024); for (pos, &token_id) in request.input_ids.iter().enumerate() {
let mut n_tokens: usize = 0; llamacpp.batch_push(
let mut n_new_tokens: usize = 0; token_id as bindings::llama_token,
pos as bindings::llama_pos,
loop { seq_id as bindings::llama_seq_id,
match unsafe { bindings::llama_decode(llamacpp.ctx, llamacpp.batch) } { true, // TODO
0 => { }, );
1 => {
unsafe {
// TODO: seq_rm & seq_add if model is compatible
bindings::llama_kv_cache_clear(llamacpp.ctx);
} }
let _ = request.tx.send(Err(InferError::IncompleteGeneration)); seqs.push(LlamacppSeq {
id: seq_id,
batch_pos: llamacpp.batch.n_tokens as usize - 1,
token: -1,
pos: request.input_ids.len() as _,
sampler: sampler,
text: String::with_capacity(1024),
n_new_tokens: 0,
running: true,
});
}
loop {
if llamacpp.batch.n_tokens == 0 {
break; break;
}, }
_ => { let decode = unsafe {
debug!("decode return <0"); bindings::llama_decode(llamacpp.ctx, llamacpp.batch)
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
break;
},
}; };
let idx = llamacpp.batch.n_tokens as usize - 1; if decode != 0 {
let (next, logprob) = sampler.sample(&mut llamacpp, idx); error!("Failed to decode batch: {decode}");
n_new_tokens += 1;
debug!("tokens: {n_tokens} new: {n_new_tokens}");
if decode == 1 {
unsafe {
bindings::llama_kv_cache_clear(llamacpp.ctx); // TODO
}
}
for seq in seqs.iter_mut() {
let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
seq.running = false;
}
break;
}
let kv_cache_used_cells = unsafe { let kv_cache_used_cells = unsafe {
bindings::llama_get_kv_cache_used_cells(llamacpp.ctx) bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
}; };
for seq in seqs.iter_mut() {
let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos);
seq.n_new_tokens += 1;
seq.token = next;
let piece = match tokenizer.decode(&[next as u32], false) { let piece = match tokenizer.decode(&[next as u32], false) {
Ok(piece) => piece, Ok(piece) => piece,
Err(e) => { Err(e) => {
error!("Failed to decode token: {e}"); error!("Failed to decode token: {e}");
let _ = request.tx.send(Err(InferError::IncompleteGeneration)); let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
seq.running = false;
break; break;
}, },
}; };
let special = vocab.is_special_token(&piece); let special = vocab.is_special_token(&piece);
if !special { if !special {
text.push_str(&piece); seq.text.push_str(&piece);
} }
let token = Token { let token = Token {
id: next as _, id: next as _,
@ -522,7 +528,7 @@ impl LlamacppBackend {
let finish: Option<FinishReason> = { let finish: Option<FinishReason> = {
if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } { if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
Some(FinishReason::EndOfSequenceToken) Some(FinishReason::EndOfSequenceToken)
} else if n_new_tokens == request.max_new_tokens { } else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
Some(FinishReason::Length) Some(FinishReason::Length)
} else if kv_cache_used_cells == llamacpp.n_ctx as i32 { } else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
Some(FinishReason::Length) // TODO: check Some(FinishReason::Length) // TODO: check
@ -531,31 +537,38 @@ impl LlamacppBackend {
} }
}; };
if let Some(reason) = finish { if let Some(reason) = finish {
let _ = request.tx.send(Ok(InferStreamResponse::End { let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::End {
token: token, token: token,
top_tokens: vec![], top_tokens: vec![],
generated_text: GeneratedText { generated_text: GeneratedText {
text: text, text: seq.text.clone(),
generated_tokens: n_new_tokens as _, generated_tokens: seq.n_new_tokens as _,
finish_reason: reason, finish_reason: reason,
seed: Some(request.seed as _), seed: Some(requests[seq.id].seed as _),
}, },
start: start_time, start: start_time,
queued: request.time, queued: requests[seq.id].time,
})); }));
seq.running = false;
break; break;
} }
let _ = request.tx.send(Ok(InferStreamResponse::Intermediate { let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::Intermediate {
token: token, token: token,
top_tokens: vec![], top_tokens: vec![],
})); }));
}
// generate a new batch
llamacpp.batch.n_tokens = 0; llamacpp.batch.n_tokens = 0;
// llamacpp.batch_clear_logits();
llamacpp.batch_push(next, pos as _, &[0], true); for seq in seqs.iter_mut() {
pos += 1; if seq.running {
llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
seq.batch_pos = 0;
seq.pos += 1;
}
}
} }
} }
} // TODO remove this
}); });
(Self{tx, status: status_rx}, ok_rx) (Self{tx, status: status_rx}, ok_rx)
} }

View File

@ -198,7 +198,7 @@ async fn main() -> Result<(), RouterError> {
flash_attention: args.flash_attention, flash_attention: args.flash_attention,
max_batch_total_tokens: args.max_batch_total_tokens, max_batch_total_tokens: args.max_batch_total_tokens,
max_batch_size: args.max_batch_size, max_batch_size: args.max_batch_size,
batch_timeout: tokio::time::Duration::from_millis(100), batch_timeout: tokio::time::Duration::from_millis(5),
}, },
tokenizer, tokenizer,
); );