Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 06:42:10 +00:00)
feat(backend): make eog clearer on c++ side
This commit is contained in:
parent 06424aa9ff
commit 11c593dc69
@@ -95,7 +95,7 @@ namespace huggingface::tgi::backends::llamacpp {
 
         // Decode
         auto n_decoded_tokens = 0;
-        for (bool generating = true; generating && n_decoded_tokens < max_new_tokens; ++n_decoded_tokens) {
+        for (bool generating = true; generating; ++n_decoded_tokens) {
             const auto callback_ = callback.value_or(llama_void_callback);
 
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
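To illustrate this first hunk: the `n_decoded_tokens < max_new_tokens` check moves out of the `for` condition, so the loop is now driven solely by the `generating` flag computed in the body (see the second hunk below). The following is a minimal, self-contained sketch of the two loop shapes, using a hypothetical token stream and an `is_eog()` stub in place of the llama.cpp sampler calls; it is an illustration, not the backend code itself.

    #include <cstdio>
    #include <vector>

    // Hypothetical token stream standing in for llama_sampler_sample();
    // token 42 plays the role of an end-of-generation (EOG) token.
    static bool is_eog(int token) { return token == 42; }

    int main() {
        const std::vector<int> stream = {7, 13, 42, 5, 5, 5, 5, 5, 5, 5};
        const int max_new_tokens = 8;

        // Old shape: the token budget lives in the for-condition, EOG in the body.
        int n_old = 0;
        for (bool generating = true; generating && n_old < max_new_tokens; ++n_old) {
            generating = !is_eog(stream[n_old]);
        }

        // New shape: the condition only checks the flag; the budget is folded
        // into the same flag inside the body (see the second hunk).
        int n_new = 0;
        for (bool generating = true; generating; ++n_new) {
            const bool has_reach_max_tokens = n_new >= max_new_tokens - 1;
            generating = !(has_reach_max_tokens || is_eog(stream[n_new]));
        }

        std::printf("old=%d new=%d\n", n_old, n_new);  // both stop after 3 tokens
        return 0;
    }

Both shapes stop after the same number of tokens; the difference is that in the new shape the final iteration already knows it is the last one, which the second hunk uses when reporting back to the callback.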
@@ -108,24 +108,27 @@ namespace huggingface::tgi::backends::llamacpp {
             const auto status = llama_decode(context, batch);
 #endif
             batch.n_tokens = 0;
-            if (LLAMA_SUCCESS(status)) {
+            if (LLAMA_SUCCESS(status)) [[likely]] {
                 // Sample the new token
                 auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
+                auto is_eog = llama_token_is_eog(mModel_.get(), new_token_id);
                 auto new_token_logits = 0.0f; // TODO: return logit
-                auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
-                auto effective_n_decoded_tokens = n_decoded_tokens + 1;
 
-                if (!generation_context.generation_params.ignore_eos_token) {
-                    generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
-                    generating = !is_eos;
-                }
+                // Push the token to the generated vector on Rust side
+                generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
+
+                // Handle termination cases
+                const auto has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
+                const auto has_reach_eog = !generation_context.generation_params.ignore_eos_token & is_eog;
+
+                generating = !(has_reach_max_tokens | has_reach_eog);
 
                 // Bubble up the generated token if a callback is provided
                 std::invoke(std::forward<const llama_decode_callback>(callback_),
                             new_token_id,
                             new_token_logits,
-                            is_eos || effective_n_decoded_tokens == max_new_tokens,
-                            effective_n_decoded_tokens);
+                            !generating,
+                            n_decoded_tokens + 1);
 
                 batch = llama_batch_get_one(&new_token_id, 1);
             }
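The following is a self-contained sketch of the consolidated termination logic in this hunk. The callback signature, the `sampled` token stream, the `is_eog_token()` stub, and the `ignore_eos_token` flag are hypothetical stand-ins for the llama.cpp and TGI types used above; the real code operates on the llama.cpp context and the Rust-side `generation_context`.

    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <vector>

    // Hypothetical callback mirroring the shape of the llama_decode_callback
    // above: (token id, logit, is_final, number of tokens decoded so far).
    using decode_callback = std::function<void(int32_t, float, bool, size_t)>;

    // Toy stand-in for llama_token_is_eog.
    static bool is_eog_token(int32_t token) { return token == 2; }

    int main() {
        const std::vector<int32_t> sampled = {11, 27, 35, 2, 64, 71, 80, 93};
        const size_t max_new_tokens = 6;
        const bool ignore_eos_token = false;  // stands in for generation_params.ignore_eos_token

        const decode_callback callback = [](int32_t id, float logit, bool is_final, size_t n) {
            std::printf("token=%d logit=%.1f final=%d n=%zu\n", id, logit, is_final, n);
        };

        std::vector<int32_t> generated_tokens(max_new_tokens);

        size_t n_decoded_tokens = 0;
        for (bool generating = true; generating; ++n_decoded_tokens) {
            const auto new_token_id = sampled[n_decoded_tokens];
            const auto new_token_logits = 0.0f;
            const auto is_eog = is_eog_token(new_token_id);

            // Push the token to the generated vector.
            generated_tokens[n_decoded_tokens] = new_token_id;

            // Handle termination cases: budget exhausted, or EOG unless ignored.
            const auto has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
            const auto has_reach_eog = !ignore_eos_token && is_eog;
            generating = !(has_reach_max_tokens || has_reach_eog);

            // The "is this the last token" flag is now simply !generating.
            std::invoke(callback, new_token_id, new_token_logits, !generating, n_decoded_tokens + 1);
        }
        return 0;
    }

Folding both stop reasons into `generating` means the callback's final-token argument becomes `!generating`, replacing the old `is_eos || effective_n_decoded_tokens == max_new_tokens` expression, and the token count reported to the callback stays `n_decoded_tokens + 1`.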