From 11c593dc69f9c7b800cd0dbac73e1e00d696867a Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Mon, 4 Nov 2024 00:11:55 +0100
Subject: [PATCH] feat(backend): make eog clearer on c++ side

---
 backends/llamacpp/csrc/backend.cpp | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp
index 733a826a7..79c09a26c 100644
--- a/backends/llamacpp/csrc/backend.cpp
+++ b/backends/llamacpp/csrc/backend.cpp
@@ -95,7 +95,7 @@ namespace huggingface::tgi::backends::llamacpp {
 
         // Decode
         auto n_decoded_tokens = 0;
-        for (bool generating = true; generating && n_decoded_tokens < max_new_tokens; ++n_decoded_tokens) {
+        for (bool generating = true; generating; ++n_decoded_tokens) {
             const auto callback_ = callback.value_or(llama_void_callback);
 
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
@@ -108,24 +108,27 @@ namespace huggingface::tgi::backends::llamacpp {
             const auto status = llama_decode(context, batch);
 #endif
             batch.n_tokens = 0;
-            if (LLAMA_SUCCESS(status)) {
+            if (LLAMA_SUCCESS(status)) [[likely]] {
                 // Sample the new token
                 auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
+                auto is_eog = llama_token_is_eog(mModel_.get(), new_token_id);
                 auto new_token_logits = 0.0f; // TODO: return logit
-                auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
-                auto effective_n_decoded_tokens = n_decoded_tokens + 1;
-                if (!generation_context.generation_params.ignore_eos_token) {
-                    generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
-                    generating = !is_eos;
-                }
+                // Push the token to the generated vector on Rust side
+                generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
+
+                // Handle termination cases
+                const auto has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
+                const auto has_reach_eog = !generation_context.generation_params.ignore_eos_token & is_eog;
+
+                generating = !(has_reach_max_tokens | has_reach_eog);
 
                 // Bubble up the generated token if a callback is provided
                 std::invoke(std::forward<decltype(callback_)>(callback_), new_token_id, new_token_logits,
                             !generating,
                             n_decoded_tokens + 1);
 
                 batch = llama_batch_get_one(&new_token_id, 1);
             }
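
For illustration, here is a minimal standalone sketch of the termination logic this patch introduces: the sampled token is always stored, and the loop keeps running only while neither the max-token budget nor an end-of-generation token (unless EOG handling is disabled) has been hit. All names below (`is_eog_token`, `fake_samples`, the EOG id) are hypothetical stand-ins, not the backend's or llama.cpp's API:

```cpp
// Sketch of the patch's reworked stop conditions, outside llama.cpp.
// `is_eog_token` stands in for llama_token_is_eog(); sampling is faked.
#include <cstdint>
#include <cstdio>
#include <vector>

using token_t = std::int32_t;

static bool is_eog_token(token_t t) { return t == 2; } // hypothetical EOG id

int main() {
    const std::size_t max_new_tokens = 8;
    const bool ignore_eos_token = false;

    std::vector<token_t> generated(max_new_tokens);
    const std::vector<token_t> fake_samples = {5, 9, 11, 2, 7, 3, 4, 6};

    std::size_t n_decoded_tokens = 0;
    for (bool generating = true; generating; ++n_decoded_tokens) {
        const token_t new_token_id = fake_samples[n_decoded_tokens];

        // The token is always recorded, even when it ends generation.
        generated[n_decoded_tokens] = new_token_id;

        // The patch's two termination cases: budget exhausted, or an
        // EOG token sampled while EOG handling is enabled.
        const bool has_reached_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
        const bool has_reached_eog = !ignore_eos_token && is_eog_token(new_token_id);
        generating = !(has_reached_max_tokens || has_reached_eog);
    }

    // With the fake samples above: stops at the EOG id after 4 tokens.
    std::printf("decoded %zu tokens (last id: %d)\n",
                n_decoded_tokens, generated[n_decoded_tokens - 1]);
    return 0;
}
```

Two behavioral points worth noting from the diff itself: the pre-patch code only stored the sampled token when `ignore_eos_token` was false, whereas the reworked loop stores it unconditionally and folds both stop conditions into a single `generating` flag; and the patch deliberately uses bitwise `&`/`|` on the boolean conditions, while the sketch above uses the conventional short-circuiting `&&`/`||`, which here is equivalent.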