feat(backend): correctly handle the max_new_tokens case for is_eos

Morgan Funtowicz 2024-11-03 23:50:46 +01:00
parent 05ff551950
commit 06424aa9ff


@@ -113,6 +113,7 @@ namespace huggingface::tgi::backends::llamacpp {
             auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
             auto new_token_logits = 0.0f; // TODO: return logit
             auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
+            auto effective_n_decoded_tokens = n_decoded_tokens + 1;

             if (!generation_context.generation_params.ignore_eos_token) {
                 generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
@@ -121,7 +122,10 @@ namespace huggingface::tgi::backends::llamacpp {
             // Bubble up the generated token if a callback is provided
             std::invoke(std::forward<const llama_decode_callback>(callback_),
-                        new_token_id, new_token_logits, is_eos, n_decoded_tokens);
+                        new_token_id,
+                        new_token_logits,
+                        is_eos || effective_n_decoded_tokens == max_new_tokens,
+                        effective_n_decoded_tokens);

             batch = llama_batch_get_one(&new_token_id, 1);
         }
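
The idea behind the change: n_decoded_tokens is zero-based, so the token produced on the final permitted iteration is number n_decoded_tokens + 1. Comparing that one-based count against max_new_tokens lets the callback see the end-of-stream flag when the token budget is exhausted, not only when the model emits an end-of-generation token. Below is a minimal standalone sketch of that pattern, not TGI's actual backend; the sampler stub, EOG_TOKEN value, and callback signature are hypothetical stand-ins for the llama.cpp calls in the diff.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>

using token_t = int32_t;
using decode_callback = std::function<void(token_t token, float logit,
                                           bool is_final, size_t n_generated)>;

// Hypothetical stand-ins for llama_sampler_sample / llama_token_is_eog.
constexpr token_t EOG_TOKEN = 2;
token_t sample_next_token(size_t step) {
    return step < 10 ? 100 + static_cast<token_t>(step) : EOG_TOKEN;
}
bool is_eog(token_t t) { return t == EOG_TOKEN; }

void decode_loop(size_t max_new_tokens, const decode_callback& callback) {
    for (size_t n_decoded_tokens = 0; n_decoded_tokens < max_new_tokens; ++n_decoded_tokens) {
        const auto new_token_id = sample_next_token(n_decoded_tokens);
        const auto is_eos = is_eog(new_token_id);

        // One-based count of tokens generated so far, including this one.
        // Reaching max_new_tokens marks the budget-exhausted token as final,
        // even when the model never emitted an end-of-generation token.
        const auto effective_n_decoded_tokens = n_decoded_tokens + 1;
        const bool is_final = is_eos || effective_n_decoded_tokens == max_new_tokens;

        callback(new_token_id, /*logit=*/0.0f, is_final, effective_n_decoded_tokens);
        if (is_final) break;
    }
}

int main() {
    // With a budget of 4 and no EOG token sampled, the fourth token
    // arrives with is_final == true instead of the stream ending silently.
    decode_loop(4, [](token_t tok, float, bool is_final, size_t n) {
        std::printf("token=%d n=%zu final=%s\n", tok, n, is_final ? "true" : "false");
    });
}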