From 06424aa9ff44a7d3edee24cb8ce7de5681222184 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Sun, 3 Nov 2024 23:50:46 +0100
Subject: [PATCH] feat(backend): correctly handle the max_new_tokens case for is_eos

---
 backends/llamacpp/csrc/backend.cpp | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp
index 54e41a14..733a826a 100644
--- a/backends/llamacpp/csrc/backend.cpp
+++ b/backends/llamacpp/csrc/backend.cpp
@@ -113,6 +113,7 @@ namespace huggingface::tgi::backends::llamacpp {
             auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
             auto new_token_logits = 0.0f; // TODO: return logit
             auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
+            auto effective_n_decoded_tokens = n_decoded_tokens + 1;
 
             if (!generation_context.generation_params.ignore_eos_token) {
                 generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
@@ -121,7 +122,10 @@ namespace huggingface::tgi::backends::llamacpp {
 
             // Bubble up the generated token if a callback is provided
             std::invoke(std::forward(callback_),
-                        new_token_id, new_token_logits, is_eos, n_decoded_tokens);
+                        new_token_id,
+                        new_token_logits,
+                        is_eos || effective_n_decoded_tokens == max_new_tokens,
+                        effective_n_decoded_tokens);
 
             batch = llama_batch_get_one(&new_token_id, 1);
         }
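
Note (not part of the patch): the stopping condition introduced above can be exercised in isolation with a minimal standalone sketch. is_final_token is a hypothetical helper written only for illustration; the names max_new_tokens, n_decoded_tokens, effective_n_decoded_tokens and the "is_eos || effective_n_decoded_tokens == max_new_tokens" expression are the ones that come from the diff, everything else is an assumption.

    #include <cstddef>
    #include <cstdio>

    // Hypothetical helper mirroring the patched condition: the token reported to
    // the callback is flagged as final either when the sampler produced an
    // end-of-generation token, or when the 1-based count of decoded tokens
    // (effective_n_decoded_tokens) reaches the max_new_tokens budget.
    static bool is_final_token(bool is_eos, std::size_t n_decoded_tokens, std::size_t max_new_tokens) {
        const std::size_t effective_n_decoded_tokens = n_decoded_tokens + 1; // account for the token just sampled
        return is_eos || effective_n_decoded_tokens == max_new_tokens;
    }

    int main() {
        // With max_new_tokens = 4, the 4th sampled token (n_decoded_tokens == 3)
        // is flagged final even though it is not an end-of-generation token.
        for (std::size_t i = 0; i < 4; ++i) {
            std::printf("token %zu -> final=%d\n", i, is_final_token(false, i, 4));
        }
        return 0;
    }

Under these assumptions the loop prints final=0 for the first three tokens and final=1 for the last one, which is the boundary case the patch addresses: before the change, hitting max_new_tokens never set the final flag passed to the callback, only an explicit end-of-generation token did.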