From 96434a1e7e65ca011051ce661ae6eb9afea88399 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Sat, 1 Feb 2025 16:09:51 +0000 Subject: [PATCH] Fix batching MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- backends/llamacpp/src/backend.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs index ebb40380..c07f0812 100644 --- a/backends/llamacpp/src/backend.rs +++ b/backends/llamacpp/src/backend.rs @@ -429,12 +429,15 @@ impl LlamacppBackend { requests = Vec::new(); continue; } - if n_tokens + request.input_ids.len() > conf.max_batch_total_tokens as usize { + let n_tokens_to_add = request.input_ids.len(); + + if n_tokens + n_tokens_to_add > conf.max_batch_total_tokens as usize { let _ = sync_tx.send(requests); - n_tokens = request.input_ids.len(); + n_tokens = n_tokens_to_add; requests = vec![request]; continue; } + n_tokens += n_tokens_to_add; requests.push(request); }, Err(_) => { @@ -487,7 +490,7 @@ impl LlamacppBackend { seqs.push(LlamacppSeq { id: seq_id, batch_pos: llamacpp.batch.n_tokens as usize - 1, - token: -1, + token: bindings::LLAMA_TOKEN_NULL, pos: last_pos as bindings::llama_pos + 1, sampler: sampler, text: String::with_capacity(1024),