diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 5b44c4e6..c76f0308 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -212,11 +212,9 @@ impl Llamacpp {
         Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
     }
 
-    fn _batch_clear_logits(&mut self) {
-        for n in 0..self.batch.n_tokens as usize{
-            unsafe {
-                *self.batch.logits.add(n) = 0 as i8;
-            }
+    fn clear_kv_cache(&mut self, seq_id: bindings::llama_seq_id) {
+        unsafe {
+            bindings::llama_kv_cache_seq_rm(self.ctx, seq_id, -1, -1);
         }
     }
 
@@ -473,11 +471,8 @@ impl LlamacppBackend {
                     bindings::llama_decode(llamacpp.ctx, llamacpp.batch)
                 };
                 if decode != 0 {
-                    warn!("llama_decode failed: kv cache clear + sync");
-                    unsafe {
-                        bindings::llama_kv_cache_clear(llamacpp.ctx);
-                        bindings::llama_synchronize(llamacpp.ctx);
-                    }
+                    warn!("llama_decode failed, clearing kv cache");
+                    llamacpp.clear_kv_cache(-1);
                     for seq in seqs.iter_mut() {
                         let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
                         seq.running = false;
@@ -555,9 +550,7 @@ impl LlamacppBackend {
                         seq.batch_pos = llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
                         seq.pos += 1;
                     } else {
-                        unsafe {
-                            bindings::llama_kv_cache_seq_rm(llamacpp.ctx, seq.id as _, -1, -1);
-                        }
+                        llamacpp.clear_kv_cache(seq.id as _);
                     }
                 }
             }
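
Note (not part of the diff): in llama.cpp, `llama_kv_cache_seq_rm(ctx, seq_id, p0, p1)` treats a negative `seq_id` as "match any sequence" and negative `p0`/`p1` as the full position range, which is what lets a single helper cover both the per-sequence cleanup and the whole-cache reset on decode failure. A minimal sketch of the new helper and its two call patterns, assuming the backend's `bindings` FFI module and `Llamacpp` struct as they appear in the diff:

```rust
impl Llamacpp {
    /// Remove KV-cache cells for one sequence, or for every sequence when
    /// seq_id is negative (llama.cpp treats seq_id < 0 as "any sequence" and
    /// p0/p1 < 0 as an unbounded position range -- semantics from llama.h,
    /// not stated in the diff itself).
    fn clear_kv_cache(&mut self, seq_id: bindings::llama_seq_id) {
        unsafe {
            bindings::llama_kv_cache_seq_rm(self.ctx, seq_id, -1, -1);
        }
    }
}

// Usage as introduced by this diff:
//   llamacpp.clear_kv_cache(-1);           // decode failed: drop the whole cache
//   llamacpp.clear_kv_cache(seq.id as _);  // sequence finished: drop only its cells
```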