mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 21:12:07 +00:00
feat(backend): correctly handle the max_new_tokens case for is_eos
This commit is contained in:
parent
05ff551950
commit
06424aa9ff
@ -113,6 +113,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||||||
auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
|
auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
|
||||||
auto new_token_logits = 0.0f; // TODO: return logit
|
auto new_token_logits = 0.0f; // TODO: return logit
|
||||||
auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
|
auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
|
||||||
|
auto effective_n_decoded_tokens = n_decoded_tokens + 1;
|
||||||
|
|
||||||
if (!generation_context.generation_params.ignore_eos_token) {
|
if (!generation_context.generation_params.ignore_eos_token) {
|
||||||
generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
|
generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
|
||||||
@ -121,7 +122,10 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||||||
|
|
||||||
// Bubble up the generated token if a callback is provided
|
// Bubble up the generated token if a callback is provided
|
||||||
std::invoke(std::forward<const llama_decode_callback>(callback_),
|
std::invoke(std::forward<const llama_decode_callback>(callback_),
|
||||||
new_token_id, new_token_logits, is_eos, n_decoded_tokens);
|
new_token_id,
|
||||||
|
new_token_logits,
|
||||||
|
is_eos || effective_n_decoded_tokens == max_new_tokens,
|
||||||
|
effective_n_decoded_tokens);
|
||||||
|
|
||||||
batch = llama_batch_get_one(&new_token_id, 1);
|
batch = llama_batch_get_one(&new_token_id, 1);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user