mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 16:32:12 +00:00)
feat(backend): use the new batch api from llama
This commit is contained in:
parent 274cfce435
commit 8e89793514
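For context: the previous implementation built its batch with llama_batch_get_one, which only carries tokens and infers positions, while the new code allocates a llama_batch explicitly so positions, sequence ids and per-token logit flags can be controlled. Below is a minimal sketch of that explicit-batch prefill pattern against the public llama.cpp API; it is a standalone illustration, not code from this commit, and error handling is trimmed.

#include <span>
#include <llama.h>

// Sketch: prefill a prompt with an explicitly initialized llama_batch,
// requesting logits only for the last prompt position.
int prefill_prompt(llama_context *ctx, std::span<const llama_token> prompt) {
    auto batch = llama_batch_init(static_cast<int32_t>(prompt.size()), /* embd */ 0, /* n_seq_max */ 1);

    for (size_t i = 0; i < prompt.size(); ++i) {
        const auto n = batch.n_tokens;
        batch.token[n]     = prompt[i];
        batch.pos[n]       = static_cast<llama_pos>(i);  // absolute position in the sequence
        batch.n_seq_id[n]  = 1;                          // each slot belongs to one sequence
        batch.seq_id[n][0] = 0;                          // sequence id (the commit below uses 1)
        batch.logits[n]    = false;                      // no logits for intermediate positions
        batch.n_tokens++;
    }
    batch.logits[batch.n_tokens - 1] = true;             // only the last position produces logits

    const int status = llama_decode(ctx, batch);         // run the forward pass over the prompt
    llama_batch_free(batch);                             // release buffers owned by the batch
    return status;
}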
@@ -38,6 +38,31 @@ namespace huggingface::tgi::backends::llamacpp {
         return {pSampler, llama_sampler_deleter};
     }
 
+    std::expected<llama_batch, backend_error_t> get_batch_from_prompt(std::span<llama_token> prompt) {
+        auto batch = llama_batch_init(static_cast<int32_t>(prompt.size()), 0, 1);
+        std::for_each(prompt.begin(), prompt.end(), [&batch](const llama_token token) {
+            const auto n_token = batch.n_tokens;
+
+            batch.token[n_token] = token;
+            batch.pos[n_token] = n_token;
+            batch.n_seq_id[n_token] = 1;
+            batch.seq_id[n_token][0] = 1;
+            batch.logits[n_token] = false;
+            batch.n_tokens++;
+        });
+
+        batch.logits[batch.n_tokens - 1] = true;
+        return batch;
+    }
+
+    void update_batch_for_decoding(llama_batch &batch, llama_token token, size_t position) {
+        batch.n_tokens = 1;
+        batch.logits[0] = true;
+        batch.token[0] = token;
+        batch.pos[0] = static_cast<int32_t>(position);
+    }
+
     worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params &&params)
             : model_(model), context_(llama_new_context_with_model(model_.get(), params)) {
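A hedged sketch of how the two helpers above compose once the prompt batch has been decoded: the same batch object is reused and shrunk to a single slot per step, so the sequence id and n_seq_id written during prefill carry over. Names other than the llama.cpp API (decode_one in particular) are illustrative and not part of this diff.

#include <cstddef>
#include <llama.h>

// Sketch: one decode step reusing the prompt batch; `position` is the absolute
// position of the sampled token (prompt size + tokens decoded so far).
llama_token decode_one(llama_context *ctx, llama_sampler *smpl, llama_batch &batch, size_t position) {
    if (llama_decode(ctx, batch) != 0) {                      // forward pass over the current batch
        return -1;                                            // decoding failed, no valid token
    }
    const auto next = llama_sampler_sample(smpl, ctx, -1);    // sample from the last logits row

    // Mirror update_batch_for_decoding: shrink the batch to one slot holding the new token.
    batch.n_tokens  = 1;
    batch.token[0]  = next;
    batch.pos[0]    = static_cast<llama_pos>(position);
    batch.logits[0] = true;                                   // logits are always needed for the next sample
    return next;
}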
@@ -59,11 +84,11 @@ namespace huggingface::tgi::backends::llamacpp {
         auto sampler = generation_context.sampling_params.into_llama_sampler(model_.get());
 
         // Set up the prompt
-        auto copy = std::vector(generation_context.input_tokens.begin(), generation_context.input_tokens.end());
-        auto batch = llama_batch_get_one(copy.data(), copy.size());
-
+        if (auto maybe_batch = get_batch_from_prompt(generation_context.input_tokens); maybe_batch.has_value()) {
         // Decode
+        auto batch = *maybe_batch;
         auto n_decoded_tokens = 0;
+        const auto prompt_size = generation_context.input_tokens.size();
         for (bool generating = true; generating; ++n_decoded_tokens) {
 
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
@@ -79,24 +104,30 @@ namespace huggingface::tgi::backends::llamacpp {
             if (LLAMA_SUCCESS(status)) [[likely]] {
                 // Sample the new token
                 auto new_token_id = llama_sampler_sample(sampler.get(), context_.get(), -1);
-                auto is_eog = llama_token_is_eog(model_.get(), new_token_id);
-                auto new_token_logits = 0.0f; // TODO: return logit
+                const auto is_eog = llama_token_is_eog(model_.get(), new_token_id);
+                const auto new_token_logits = llama_get_logits_ith(context_.get(), -1); // TODO: return logit
 
                 // Handle termination cases
-                const auto has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
-                const auto has_reach_eog = !generation_context.generation_params.ignore_eos_token & is_eog;
-                generating = !(has_reach_max_tokens | has_reach_eog);
+                const bool has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
+                const bool has_reach_eog = !generation_context.generation_params.ignore_eos_token & is_eog;
+                const bool is_final = has_reach_max_tokens | has_reach_eog;
 
                 // Bubble up the generated token if a callback is provided
-                const auto should_stop =
-                        std::invoke(callback_, new_token_id, new_token_logits, !generating, n_decoded_tokens + 1);
-                generating ^= should_stop;
+                const auto should_stop = callback_(new_token_id, *new_token_logits, is_final, n_decoded_tokens + 1);
 
-                batch = llama_batch_get_one(&new_token_id, 1);
+                // Compute the continuation flag
+                generating = !(should_stop | is_final);
+
+                // Update the batch for the next generation
+                update_batch_for_decoding(batch, new_token_id, prompt_size + n_decoded_tokens);
             }
         }
 
+        llama_batch_free(batch);
+
         return n_decoded_tokens;
+        } else {
+            return maybe_batch.error();
+        }
     }
 }
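The loop above forwards each sampled token to callback_, whose return value can request early termination; the callback type itself is defined elsewhere in the backend and is not part of this diff. As a purely hypothetical illustration of a compatible shape, inferred only from the call site above, a callback that streams tokens and never stops early might look like this:

#include <cstdio>
#include <llama.h>

// Hypothetical callback matching the call site
//   callback_(new_token_id, *new_token_logits, is_final, n_decoded_tokens + 1);
// the real callback_ type lives elsewhere in the backend and may differ.
auto callback_ = [](llama_token token, float logits, bool is_final, size_t n_generated) -> bool {
    std::fprintf(stderr, "token=%d logits=%.3f final=%d generated=%zu\n",
                 token, logits, is_final ? 1 : 0, n_generated);
    return false; // returning true would request that generation stop early
};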
@@ -75,7 +75,7 @@ namespace huggingface::tgi::backends::llamacpp {
     struct generation_context_t {
         generation_params_t generation_params;
         sampling_params_t sampling_params;
-        std::span<const llama_token> input_tokens;
+        std::span<llama_token> input_tokens;
     };
 
     /**