From 7b0a56f40fc5766bef8c707a800b53462399f31c Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Sun, 3 Nov 2024 11:17:02 +0100 Subject: [PATCH] feat(backend): fix memory leaking on llama_sampler when the decode ends --- backends/llamacpp/csrc/backend.cpp | 4 ++-- backends/llamacpp/csrc/backend.hpp | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp index b88067f8..4b608620 100644 --- a/backends/llamacpp/csrc/backend.cpp +++ b/backends/llamacpp/csrc/backend.cpp @@ -29,7 +29,7 @@ namespace huggingface::tgi::backends::llamacpp { batch.logits[batch.n_tokens] = true; } - std::unique_ptr sampling_params_t::into_llama_sampler(const llama_model *model) const { + llama_sampler_ptr sampling_params_t::into_llama_sampler(const llama_model *model) const { auto *pSampler = llama_sampler_chain_init({.no_perf = false}); // Penalties @@ -51,7 +51,7 @@ namespace huggingface::tgi::backends::llamacpp { } llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed)); - return std::unique_ptr(pSampler); + return llama_sampler_ptr(pSampler, llama_sampler_deleter); } worker_t::worker_t(std::shared_ptr model, const llama_context_params ¶ms) diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp index 288bf36a..70f99268 100644 --- a/backends/llamacpp/csrc/backend.hpp +++ b/backends/llamacpp/csrc/backend.hpp @@ -24,7 +24,10 @@ namespace huggingface::tgi::backends::llamacpp { static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); }; - typedef std::unique_ptr llama_context_smart_ptr; + typedef std::unique_ptr llama_context_ptr; + + static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); }; + typedef std::unique_ptr llama_sampler_ptr; typedef std::function llama_decode_callback; static constexpr auto llama_void_callback = [](llama_token, float_t, bool) {}; @@ -51,7 +54,7 @@ namespace huggingface::tgi::backends::llamacpp { * @param Pointer to the model data * @return */ - std::unique_ptr into_llama_sampler(const llama_model *pModel) const; + llama_sampler_ptr into_llama_sampler(const llama_model *pModel) const; }; /** @@ -155,7 +158,7 @@ namespace huggingface::tgi::backends::llamacpp { class single_worker_backend_t : backend_base_t { private: - constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_smart_ptr { + constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_ptr { auto llParams = llama_context_default_params(); llParams.flash_attn = true; llParams.n_batch = 1; @@ -165,7 +168,7 @@ namespace huggingface::tgi::backends::llamacpp { return {llama_new_context_with_model(pModel, llParams), llama_context_deleter}; }; - llama_context_smart_ptr mContext_; + llama_context_ptr mContext_; worker_t mWorker_; public: @@ -185,7 +188,7 @@ namespace huggingface::tgi::backends::llamacpp { class multi_worker_backend_t : backend_base_t { private: - llama_context_smart_ptr mContext_; + llama_context_ptr mContext_; public: std::expected generate(