Add a stupid batch mechanism

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Adrien Gallouët 2025-01-31 12:44:09 +00:00
parent e07835c5b5
commit d6ded897a8
2 changed files with 119 additions and 106 deletions
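In short: batch_push() now takes a single seq_id instead of a slice of ids, a new LlamacppSeq struct carries per-request state (sampler, current position, accumulated text, running flag), and the worker fills one shared batch with every queued request before each llama_decode() call, sampling one token per live sequence per step instead of running requests one at a time. The request buffer is preallocated to max_batch_size, and the router's batch_timeout drops from 100 ms to 5 ms.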

View File

@@ -202,7 +202,7 @@ impl Llamacpp {
Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
}
fn batch_clear_logits(&mut self) {
fn _batch_clear_logits(&mut self) {
for n in 0..self.batch.n_tokens as usize{
unsafe {
*self.batch.logits.add(n) = 0 as i8;
@@ -214,24 +214,15 @@ impl Llamacpp {
&mut self,
token: bindings::llama_token,
pos: bindings::llama_pos,
seq_ids: &[bindings::llama_seq_id],
seq_id: bindings::llama_seq_id,
logits: bool,
) {
debug!("push {token} {pos} {logits}");
// TODO check everything..
let n = self.batch.n_tokens as usize;
unsafe {
*self.batch.token.add(n) = token;
*self.batch.pos.add(n) = pos;
*self.batch.n_seq_id.add(n) = seq_ids.len() as i32;
}
for (i, &seq_id) in seq_ids.iter().enumerate() {
unsafe {
*(*self.batch.seq_id.add(n)).add(i) = seq_id;
}
}
unsafe {
*self.batch.n_seq_id.add(n) = 1;
*(*self.batch.seq_id.add(n)).add(0) = seq_id;
*self.batch.logits.add(n) = logits as i8;
}
self.batch.n_tokens += 1;
@@ -375,6 +366,17 @@ impl Drop for LlamacppSampler {
}
}
struct LlamacppSeq {
id: usize,
batch_pos: usize,
token: bindings::llama_token,
pos: bindings::llama_pos,
sampler: LlamacppSampler,
text: String,
n_new_tokens: usize,
running: bool,
}
static INIT: Once = Once::new();
impl LlamacppBackend {
@@ -397,7 +399,7 @@ impl LlamacppBackend {
spawn(async move {
let mut n_tokens = 0;
let mut requests = Vec::new();
let mut requests = Vec::with_capacity(conf.max_batch_size);
loop {
match timeout(conf.batch_timeout, rx.recv()).await {
@@ -442,27 +444,12 @@ impl LlamacppBackend {
let _ = status_tx.send(true);
while let Ok(requests) = sync_rx.recv() {
// TODO: do a real batch
for (_seq_id, request) in requests.iter().enumerate() {
debug!("Request: {:?}", request);
let start_time = Instant::now();
let mut seqs: Vec<LlamacppSeq> = Vec::with_capacity(requests.len());
llamacpp.batch.n_tokens = 0;
for (pos, &token_id) in request.input_ids.iter().enumerate() {
llamacpp.batch_push(
token_id as bindings::llama_token,
pos as bindings::llama_pos,
&[/* seq_id */ 0 as bindings::llama_seq_id],
true,
);
}
let mut pos = request.input_ids.len();
// TODO: close this loop :)
// TODO: move up for perf ?
for (seq_id, request) in requests.iter().enumerate() {
debug!("Request: {:?}", request);
let sampler = match LlamacppSampler::new(&request) {
Some(sampler) => sampler,
_ => {
@@ -470,48 +457,67 @@ impl LlamacppBackend {
continue;
},
};
let mut text = String::with_capacity(1024);
let mut n_tokens: usize = 0;
let mut n_new_tokens: usize = 0;
loop {
match unsafe { bindings::llama_decode(llamacpp.ctx, llamacpp.batch) } {
0 => { },
1 => {
unsafe {
// TODO: seq_rm & seq_add if model is compatible
bindings::llama_kv_cache_clear(llamacpp.ctx);
for (pos, &token_id) in request.input_ids.iter().enumerate() {
llamacpp.batch_push(
token_id as bindings::llama_token,
pos as bindings::llama_pos,
seq_id as bindings::llama_seq_id,
true, // TODO
);
}
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
seqs.push(LlamacppSeq {
id: seq_id,
batch_pos: llamacpp.batch.n_tokens as usize - 1,
token: -1,
pos: request.input_ids.len() as _,
sampler: sampler,
text: String::with_capacity(1024),
n_new_tokens: 0,
running: true,
});
}
loop {
if llamacpp.batch.n_tokens == 0 {
break;
},
_ => {
debug!("decode return <0");
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
break;
},
}
let decode = unsafe {
bindings::llama_decode(llamacpp.ctx, llamacpp.batch)
};
let idx = llamacpp.batch.n_tokens as usize - 1;
let (next, logprob) = sampler.sample(&mut llamacpp, idx);
n_new_tokens += 1;
debug!("tokens: {n_tokens} new: {n_new_tokens}");
if decode != 0 {
error!("Failed to decode batch: {decode}");
if decode == 1 {
unsafe {
bindings::llama_kv_cache_clear(llamacpp.ctx); // TODO
}
}
for seq in seqs.iter_mut() {
let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
seq.running = false;
}
break;
}
let kv_cache_used_cells = unsafe {
bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
};
for seq in seqs.iter_mut() {
let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos);
seq.n_new_tokens += 1;
seq.token = next;
let piece = match tokenizer.decode(&[next as u32], false) {
Ok(piece) => piece,
Err(e) => {
error!("Failed to decode token: {e}");
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
seq.running = false;
break;
},
};
let special = vocab.is_special_token(&piece);
if !special {
text.push_str(&piece);
seq.text.push_str(&piece);
}
let token = Token {
id: next as _,
@@ -522,7 +528,7 @@ impl LlamacppBackend {
let finish: Option<FinishReason> = {
if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
Some(FinishReason::EndOfSequenceToken)
} else if n_new_tokens == request.max_new_tokens {
} else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
Some(FinishReason::Length)
} else if kv_cache_used_cells == llamacpp.n_ctx as i32 {
Some(FinishReason::Length) // TODO: check
@@ -531,31 +537,38 @@ impl LlamacppBackend {
}
};
if let Some(reason) = finish {
let _ = request.tx.send(Ok(InferStreamResponse::End {
let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::End {
token: token,
top_tokens: vec![],
generated_text: GeneratedText {
text: text,
generated_tokens: n_new_tokens as _,
text: seq.text.clone(),
generated_tokens: seq.n_new_tokens as _,
finish_reason: reason,
seed: Some(request.seed as _),
seed: Some(requests[seq.id].seed as _),
},
start: start_time,
queued: request.time,
queued: requests[seq.id].time,
}));
seq.running = false;
break;
}
let _ = request.tx.send(Ok(InferStreamResponse::Intermediate {
let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::Intermediate {
token: token,
top_tokens: vec![],
}));
}
// generate a new batch
llamacpp.batch.n_tokens = 0;
// llamacpp.batch_clear_logits();
llamacpp.batch_push(next, pos as _, &[0], true);
pos += 1;
for seq in seqs.iter_mut() {
if seq.running {
llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
seq.batch_pos = 0;
seq.pos += 1;
}
}
}
}
} // TODO remove this
});
(Self{tx, status: status_rx}, ok_rx)
}
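To make the new control flow easier to follow outside the diff, here is a condensed sketch of the scheme this file now implements: prefill every request into one shared batch under its own seq_id, decode once per step, sample one token per live sequence, then rebuild the batch from the survivors. Everything below other than the LlamacppSeq-style bookkeeping is a hypothetical stand-in (Batch, Seq, decode, dummy_sample), not the backend's real FFI or sampler API.

// Self-contained sketch; dummy types replace the llama.cpp bindings.
struct Batch {
    // (token, pos, seq_id) triplets, standing in for llama_batch's parallel arrays
    tokens: Vec<(i32, i32, i32)>,
}

impl Batch {
    fn clear(&mut self) { self.tokens.clear(); }
    fn push(&mut self, token: i32, pos: i32, seq_id: i32) -> usize {
        self.tokens.push((token, pos, seq_id));
        self.tokens.len() - 1 // index later used to read this entry's logits
    }
}

// Per-request state, in the spirit of the new LlamacppSeq struct.
struct Seq {
    id: usize,        // index into the original requests
    batch_pos: usize, // where this sequence's last token sits in the batch
    token: i32,       // last sampled token
    pos: i32,         // next position in the KV cache
    n_new: usize,     // number of generated tokens so far
    running: bool,
}

fn decode(_batch: &Batch) -> i32 { 0 }                              // dummy: 0 means success
fn dummy_sample(_batch_pos: usize, pos: i32) -> i32 { pos % 7 + 1 } // dummy sampler

fn run(prompts: &[Vec<i32>], max_new_tokens: usize) {
    let mut batch = Batch { tokens: Vec::new() };
    let mut seqs: Vec<Seq> = Vec::with_capacity(prompts.len());

    // 1. Prefill: every request's prompt goes into the same batch, each under
    //    its own seq_id, and we remember where its last token landed.
    for (seq_id, prompt) in prompts.iter().enumerate() {
        let mut last = 0;
        for (pos, &tok) in prompt.iter().enumerate() {
            last = batch.push(tok, pos as i32, seq_id as i32);
        }
        seqs.push(Seq {
            id: seq_id,
            batch_pos: last,
            token: -1,
            pos: prompt.len() as i32,
            n_new: 0,
            running: true,
        });
    }

    // 2. Step loop: one decode per step, then one sampled token per live sequence.
    loop {
        if batch.tokens.is_empty() {
            break;
        }
        if decode(&batch) != 0 {
            // The whole batch fails together (the real code reports
            // IncompleteGeneration to each request).
            for seq in seqs.iter_mut() { seq.running = false; }
            break;
        }
        for seq in seqs.iter_mut().filter(|s| s.running) {
            let next = dummy_sample(seq.batch_pos, seq.pos);
            seq.token = next;
            seq.n_new += 1;
            if seq.n_new >= max_new_tokens { // or EOS / full KV cache in the real code
                seq.running = false;
            }
        }
        // 3. Rebuild the batch with one new token per still-running sequence.
        batch.clear();
        for seq in seqs.iter_mut().filter(|s| s.running) {
            seq.batch_pos = batch.push(seq.token, seq.pos, seq.id as i32);
            seq.pos += 1;
        }
    }
}

fn main() {
    run(&[vec![1, 2, 3], vec![4, 5]], 4);
}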

View File

@@ -198,7 +198,7 @@ async fn main() -> Result<(), RouterError> {
flash_attention: args.flash_attention,
max_batch_total_tokens: args.max_batch_total_tokens,
max_batch_size: args.max_batch_size,
batch_timeout: tokio::time::Duration::from_millis(100),
batch_timeout: tokio::time::Duration::from_millis(5),
},
tokenizer,
);
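With this change the router hands the backend a batch at most 5 ms after a request arrives instead of 100 ms. Below is a minimal sketch of the accumulate-then-flush pattern that this timeout feeds; the Request type, channel types and flush-on-full rule are assumptions for illustration, and the running token count that the real loop also tracks is omitted.

use tokio::sync::mpsc;
use tokio::time::{timeout, Duration};

// Illustrative request type; the real one carries input_ids, sampling
// parameters, a response channel, etc.
#[derive(Debug)]
struct Request {
    input_ids: Vec<u32>,
}

async fn batching_loop(
    mut rx: mpsc::UnboundedReceiver<Request>,
    sync_tx: std::sync::mpsc::Sender<Vec<Request>>,
    max_batch_size: usize,
    batch_timeout: Duration,
) {
    let mut requests = Vec::with_capacity(max_batch_size);
    loop {
        match timeout(batch_timeout, rx.recv()).await {
            // A request arrived before the timeout: buffer it and flush
            // immediately if the batch is full.
            Ok(Some(request)) => {
                requests.push(request);
                if requests.len() >= max_batch_size {
                    let batch = std::mem::replace(&mut requests, Vec::with_capacity(max_batch_size));
                    let _ = sync_tx.send(batch);
                }
            }
            // Nothing new for batch_timeout (5 ms here): flush whatever is pending.
            Err(_) => {
                if !requests.is_empty() {
                    let batch = std::mem::replace(&mut requests, Vec::with_capacity(max_batch_size));
                    let _ = sync_tx.send(batch);
                }
            }
            // Channel closed: stop batching.
            Ok(None) => break,
        }
    }
}

The smaller timeout presumably trades a little batching opportunity for much lower added latency: a lone request waits at most 5 ms, while near-simultaneous requests can still share the same decode step.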