From ea28332bb387bf6f0437f44f463e231809189409 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?=
Date: Sat, 1 Feb 2025 20:40:59 +0000
Subject: [PATCH] Cleanup
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Adrien Gallouët
---
 backends/llamacpp/src/backend.rs | 19 ++++++-------------
 1 file changed, 6 insertions(+), 13 deletions(-)

diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 5b44c4e6..c76f0308 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -212,11 +212,9 @@ impl Llamacpp {
         Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
     }
 
-    fn _batch_clear_logits(&mut self) {
-        for n in 0..self.batch.n_tokens as usize{
-            unsafe {
-                *self.batch.logits.add(n) = 0 as i8;
-            }
+    fn clear_kv_cache(&mut self, seq_id: bindings::llama_seq_id) {
+        unsafe {
+            bindings::llama_kv_cache_seq_rm(self.ctx, seq_id, -1, -1);
         }
     }
 
@@ -473,11 +471,8 @@ impl LlamacppBackend {
                 bindings::llama_decode(llamacpp.ctx, llamacpp.batch)
             };
             if decode != 0 {
-                warn!("llama_decode failed: kv cache clear + sync");
-                unsafe {
-                    bindings::llama_kv_cache_clear(llamacpp.ctx);
-                    bindings::llama_synchronize(llamacpp.ctx);
-                }
+                warn!("llama_decode failed, clearing kv cache");
+                llamacpp.clear_kv_cache(-1);
                 for seq in seqs.iter_mut() {
                     let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
                     seq.running = false;
@@ -555,9 +550,7 @@ impl LlamacppBackend {
                     seq.batch_pos = llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
                     seq.pos += 1;
                 } else {
-                    unsafe {
-                        bindings::llama_kv_cache_seq_rm(llamacpp.ctx, seq.id as _, -1, -1);
-                    }
+                    llamacpp.clear_kv_cache(seq.id as _);
                 }
             }
         }