From 161280f3136e643481580536099d3b70752e19c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Sat, 1 Feb 2025 10:51:44 +0000 Subject: [PATCH] Only export the latest logits MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- backends/llamacpp/src/backend.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs index 2ad0e491..f95157f5 100644 --- a/backends/llamacpp/src/backend.rs +++ b/backends/llamacpp/src/backend.rs @@ -474,19 +474,21 @@ impl LlamacppBackend { continue; }, }; + let last_pos = request.input_ids.len() - 1; + for (pos, &token_id) in request.input_ids.iter().enumerate() { llamacpp.batch_push( token_id as bindings::llama_token, pos as bindings::llama_pos, seq_id as bindings::llama_seq_id, - true, // TODO + pos == last_pos, // check samplers ); } seqs.push(LlamacppSeq { id: seq_id, batch_pos: llamacpp.batch.n_tokens as usize - 1, token: -1, - pos: request.input_ids.len() as _, + pos: last_pos as bindings::llama_pos + 1, sampler: sampler, text: String::with_capacity(1024), n_new_tokens: 0,