From 695b1292e95ae70b9db228676f073e11e8ec711e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Wed, 5 Feb 2025 15:42:59 +0000 Subject: [PATCH] Ensure all samplers are freed on error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- backends/llamacpp/src/backend.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs index d81137e6..fa0e7beb 100644 --- a/backends/llamacpp/src/backend.rs +++ b/backends/llamacpp/src/backend.rs @@ -365,14 +365,17 @@ impl LlamacppSampler { let dist = unsafe { llamacpp::sampler_init_dist(req.seed) }; + let all = &[ + ("top_k", top_k), + ("top_p", top_p), + ("typical_p", typical_p), + ("temp", temp), + ("penalties", penalties), + ("dist", dist), + ]; let mut failed = false; - for (k, v) in &[( "top_k", top_k ), - ( "top_p", top_p ), - ("typical_p", typical_p), - ( "temp", temp ), - ("penalties", penalties), - ( "dist", dist )] { + for (k, v) in all { if v.is_null() { error!("Failed to init {k} sampler"); failed = true; @@ -381,6 +384,7 @@ impl LlamacppSampler { } } if failed { + unsafe { llamacpp::sampler_free(chain) }; None } else { Some(LlamacppSampler{chain})