From d6ded897a888373500040e0f013247bca3081b42 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?=
Date: Fri, 31 Jan 2025 12:44:09 +0000
Subject: [PATCH] Add a stupid batch mechanism
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Adrien Gallouët
---
 backends/llamacpp/src/backend.rs | 223 ++++++++++++++++---------------
 backends/llamacpp/src/main.rs    |   2 +-
 2 files changed, 119 insertions(+), 106 deletions(-)

diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 53f2c098..ba5ca186 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -202,7 +202,7 @@ impl Llamacpp {
         Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
     }
 
-    fn batch_clear_logits(&mut self) {
+    fn _batch_clear_logits(&mut self) {
         for n in 0..self.batch.n_tokens as usize{
             unsafe {
                 *self.batch.logits.add(n) = 0 as i8;
@@ -214,24 +214,15 @@ impl Llamacpp {
         &mut self,
         token: bindings::llama_token,
         pos: bindings::llama_pos,
-        seq_ids: &[bindings::llama_seq_id],
+        seq_id: bindings::llama_seq_id,
         logits: bool,
     ) {
-        debug!("push {token} {pos} {logits}");
-        // TODO check evertyhing..
         let n = self.batch.n_tokens as usize;
-
         unsafe {
             *self.batch.token.add(n) = token;
             *self.batch.pos.add(n) = pos;
-            *self.batch.n_seq_id.add(n) = seq_ids.len() as i32;
-        }
-        for (i, &seq_id) in seq_ids.iter().enumerate() {
-            unsafe {
-                *(*self.batch.seq_id.add(n)).add(i) = seq_id;
-            }
-        }
-        unsafe {
+            *self.batch.n_seq_id.add(n) = 1;
+            *(*self.batch.seq_id.add(n)).add(0) = seq_id;
             *self.batch.logits.add(n) = logits as i8;
         }
         self.batch.n_tokens += 1;
@@ -375,6 +366,17 @@ impl Drop for LlamacppSampler {
     }
 }
 
+struct LlamacppSeq {
+    id: usize,
+    batch_pos: usize,
+    token: bindings::llama_token,
+    pos: bindings::llama_pos,
+    sampler: LlamacppSampler,
+    text: String,
+    n_new_tokens: usize,
+    running: bool,
+}
+
 static INIT: Once = Once::new();
 
 impl LlamacppBackend {
@@ -397,7 +399,7 @@
 
         spawn(async move {
             let mut n_tokens = 0;
-            let mut requests = Vec::new();
+            let mut requests = Vec::with_capacity(conf.max_batch_size);
 
             loop {
                 match timeout(conf.batch_timeout, rx.recv()).await {
@@ -442,120 +444,131 @@
             let _ = status_tx.send(true);
 
             while let Ok(requests) = sync_rx.recv() {
+                let start_time = Instant::now();
+                let mut seqs: Vec<LlamacppSeq> = Vec::with_capacity(requests.len());
+                llamacpp.batch.n_tokens = 0;
 
-                // TODO: do a real batch
-                for (_seq_id, request) in requests.iter().enumerate() {
-
+                for (seq_id, request) in requests.iter().enumerate() {
                     debug!("Request: {:?}", request);
-                    let start_time = Instant::now();
-                    llamacpp.batch.n_tokens = 0;
-
+                    let sampler = match LlamacppSampler::new(&request) {
+                        Some(sampler) => sampler,
+                        _ => {
+                            let _ = request.tx.send(Err(InferError::IncompleteGeneration));
+                            continue;
+                        },
+                    };
                     for (pos, &token_id) in request.input_ids.iter().enumerate() {
                         llamacpp.batch_push(
                             token_id as bindings::llama_token,
                             pos as bindings::llama_pos,
-                            &[/* seq_id */ 0 as bindings::llama_seq_id],
-                            true,
+                            seq_id as bindings::llama_seq_id,
+                            true, // TODO
                         );
                     }
-                    let mut pos = request.input_ids.len();
-
-                    // TODO: close this loop :)
-
-                    // TODO: move up for perf ?
-                    let sampler = match LlamacppSampler::new(&request) {
-                        Some(sampler) => sampler,
-                        _ => {
-                            let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                            continue;
-                        },
-                    };
-                    let mut text = String::with_capacity(1024);
-                    let mut n_tokens: usize = 0;
-                    let mut n_new_tokens: usize = 0;
-
+                    seqs.push(LlamacppSeq {
+                        id: seq_id,
+                        batch_pos: llamacpp.batch.n_tokens as usize - 1,
+                        token: -1,
+                        pos: request.input_ids.len() as _,
+                        sampler: sampler,
+                        text: String::with_capacity(1024),
+                        n_new_tokens: 0,
+                        running: true,
+                    });
+                }
                 loop {
-                    match unsafe { bindings::llama_decode(llamacpp.ctx, llamacpp.batch) } {
-                        0 => { },
-                        1 => {
-                            unsafe {
-                                // TODO: seq_rm & seq_add if model is compatible
-                                bindings::llama_kv_cache_clear(llamacpp.ctx);
-                            }
-                            let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                            break;
-                        },
-                        _ => {
-                            debug!("decode return <0");
-                            let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                            break;
-                        },
-                    };
-                    let idx = llamacpp.batch.n_tokens as usize - 1;
-                    let (next, logprob) = sampler.sample(&mut llamacpp, idx);
-                    n_new_tokens += 1;
-
-                    debug!("tokens: {n_tokens} new: {n_new_tokens}");
+                    if llamacpp.batch.n_tokens == 0 {
+                        break;
+                    }
+                    let decode = unsafe {
+                        bindings::llama_decode(llamacpp.ctx, llamacpp.batch)
+                    };
+                    if decode != 0 {
+                        error!("Failed to decode batch: {decode}");
+                        if decode == 1 {
+                            unsafe {
+                                bindings::llama_kv_cache_clear(llamacpp.ctx); // TODO
+                            }
+                        }
+                        for seq in seqs.iter_mut() {
+                            let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                            seq.running = false;
+                        }
+                        break;
+                    }
                     let kv_cache_used_cells = unsafe {
                         bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
                     };
-                    let piece = match tokenizer.decode(&[next as u32], false) {
-                        Ok(piece) => piece,
-                        Err(e) => {
-                            error!("Failed to decode token: {e}");
-                            let _ = request.tx.send(Err(InferError::IncompleteGeneration));
-                            break;
-                        },
-                    };
-                    let special = vocab.is_special_token(&piece);
-
-                    if !special {
-                        text.push_str(&piece);
-                    }
-                    let token = Token {
-                        id: next as _,
-                        text: piece,
-                        logprob: logprob,
-                        special: special,
-                    };
-                    let finish: Option<FinishReason> = {
-                        if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
-                            Some(FinishReason::EndOfSequenceToken)
-                        } else if n_new_tokens == request.max_new_tokens {
-                            Some(FinishReason::Length)
-                        } else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
-                            Some(FinishReason::Length) // TODO: check
-                        } else {
-                            None
-                        }
-                    };
+                    for seq in seqs.iter_mut() {
+                        let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos);
+                        seq.n_new_tokens += 1;
+                        seq.token = next;
+
+                        let piece = match tokenizer.decode(&[next as u32], false) {
+                            Ok(piece) => piece,
+                            Err(e) => {
+                                error!("Failed to decode token: {e}");
+                                let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
+                                seq.running = false;
+                                continue; // keep processing the other sequences in this batch
+                            },
+                        };
+                        let special = vocab.is_special_token(&piece);
+
+                        if !special {
+                            seq.text.push_str(&piece);
+                        }
+                        let token = Token {
+                            id: next as _,
+                            text: piece,
+                            logprob: logprob,
+                            special: special,
+                        };
+                        let finish: Option<FinishReason> = {
+                            if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
+                                Some(FinishReason::EndOfSequenceToken)
+                            } else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
+                                Some(FinishReason::Length)
+                            } else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
+                                Some(FinishReason::Length) // TODO: check
+                            } else {
+                                None
+                            }
+                        };
+                        if let Some(reason) = finish {
+                            let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::End {
+                                token: token,
+                                top_tokens: vec![],
+                                generated_text: GeneratedText {
+                                    text: seq.text.clone(),
+                                    generated_tokens: seq.n_new_tokens as _,
+                                    finish_reason: reason,
+                                    seed: Some(requests[seq.id].seed as _),
+                                },
+                                start: start_time,
+                                queued: requests[seq.id].time,
+                            }));
+                            seq.running = false;
+                            continue; // keep sampling the remaining sequences
+                        }
+                        let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::Intermediate {
+                            token: token,
+                            top_tokens: vec![],
+                        }));
+                    }
-                    if let Some(reason) = finish {
-                        let _ = request.tx.send(Ok(InferStreamResponse::End {
-                            token: token,
-                            top_tokens: vec![],
-                            generated_text: GeneratedText {
-                                text: text,
-                                generated_tokens: n_new_tokens as _,
-                                finish_reason: reason,
-                                seed: Some(request.seed as _),
-                            },
-                            start: start_time,
-                            queued: request.time,
-                        }));
-                        break;
-                    }
-                    let _ = request.tx.send(Ok(InferStreamResponse::Intermediate {
-                        token: token,
-                        top_tokens: vec![],
-                    }));
+                    // generate a new batch
                     llamacpp.batch.n_tokens = 0;
-                    // llamacpp.batch_clear_logits();
-                    llamacpp.batch_push(next, pos as _, &[0], true);
-                    pos += 1;
+
+                    for seq in seqs.iter_mut() {
+                        if seq.running {
+                            llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
+                            seq.batch_pos = llamacpp.batch.n_tokens as usize - 1; // this seq's slot in the new batch
+                            seq.pos += 1;
+                        }
+                    }
                 }
             }
-            } // TODO remove this
         });

         (Self{tx, status: status_rx}, ok_rx)
     }
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 53a83aa1..e1edd72d 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -198,7 +198,7 @@ async fn main() -> Result<(), RouterError> {
             flash_attention: args.flash_attention,
             max_batch_total_tokens: args.max_batch_total_tokens,
             max_batch_size: args.max_batch_size,
-            batch_timeout: tokio::time::Duration::from_millis(100),
+            batch_timeout: tokio::time::Duration::from_millis(5),
         },
         tokenizer,
     );
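
Note on the mechanism: the scheduling loop above is easier to follow without the unsafe llama.cpp bindings. Below is a minimal, self-contained Rust sketch of the same idea; the `Seq` struct and the `step()` function are illustrative stand-ins for `LlamacppSeq` and for `llama_decode` plus sampling, not the backend's real API. Prefill pushes every prompt token of every request into a single batch and records each request's last slot; every decode step then feeds exactly one new token per still-running sequence, and each sequence reads logits only from its own slot.

// Minimal sketch of the batching scheme: prefill all prompts into one
// batch, then feed one token per running sequence on each decode step.
// `Seq` and `step()` are illustrative stand-ins, not the backend's types.

struct Seq {
    id: usize,        // also used as the sequence id in the KV cache
    batch_pos: usize, // index of this sequence's token in the current batch
    pos: usize,       // next position for this sequence
    token: i32,       // last sampled token
    n_new: usize,     // tokens generated so far
    running: bool,
}

// Stand-in for llama_decode + sampling: one "next token" per batch slot.
fn step(batch: &[(i32, usize, usize)]) -> Vec<i32> {
    batch.iter().map(|&(token, _pos, _seq_id)| token + 1).collect()
}

fn main() {
    let prompts: Vec<Vec<i32>> = vec![vec![1, 2, 3], vec![40, 50]];
    let max_new_tokens = 4;

    // Prefill: all prompt tokens of all requests go into a single batch.
    let mut batch: Vec<(i32, usize, usize)> = Vec::new(); // (token, pos, seq_id)
    let mut seqs: Vec<Seq> = Vec::new();

    for (id, prompt) in prompts.iter().enumerate() {
        for (pos, &token) in prompt.iter().enumerate() {
            batch.push((token, pos, id));
        }
        seqs.push(Seq {
            id,
            batch_pos: batch.len() - 1, // sample from the last prompt token
            pos: prompt.len(),
            token: -1,
            n_new: 0,
            running: true,
        });
    }

    // Decode loop: one step per batch until every sequence has finished.
    while !batch.is_empty() {
        let next_tokens = step(&batch);

        for seq in seqs.iter_mut().filter(|s| s.running) {
            seq.token = next_tokens[seq.batch_pos]; // this sequence's slot only
            seq.n_new += 1;
            println!("seq {} -> token {}", seq.id, seq.token);
            if seq.n_new == max_new_tokens {
                seq.running = false; // FinishReason::Length in the real backend
            }
        }

        // Rebuild the batch: exactly one new token per running sequence.
        batch.clear();
        for seq in seqs.iter_mut().filter(|s| s.running) {
            batch.push((seq.token, seq.pos, seq.id));
            seq.batch_pos = batch.len() - 1; // slot shifts as others finish
            seq.pos += 1;
        }
    }
}

Because sequences finish at different times, a sequence's slot index shifts whenever the batch is rebuilt, which is why `batch_pos` has to be refreshed to the slot just pushed rather than treated as constant.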