diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs index 2ad0e491..f95157f5 100644 --- a/backends/llamacpp/src/backend.rs +++ b/backends/llamacpp/src/backend.rs @@ -474,19 +474,21 @@ impl LlamacppBackend { continue; }, }; + let last_pos = request.input_ids.len() - 1; + for (pos, &token_id) in request.input_ids.iter().enumerate() { llamacpp.batch_push( token_id as bindings::llama_token, pos as bindings::llama_pos, seq_id as bindings::llama_seq_id, - true, // TODO + pos == last_pos, // check samplers ); } seqs.push(LlamacppSeq { id: seq_id, batch_pos: llamacpp.batch.n_tokens as usize - 1, token: -1, - pos: request.input_ids.len() as _, + pos: last_pos as bindings::llama_pos + 1, sampler: sampler, text: String::with_capacity(1024), n_new_tokens: 0,