diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp
index 080a4401..daf8de54 100644
--- a/backends/llamacpp/csrc/backend.cpp
+++ b/backends/llamacpp/csrc/backend.cpp
@@ -16,85 +16,156 @@ namespace huggingface::tgi::backends::llamacpp {
-    std::unique_ptr<llama_sampler> SamplingParams::IntoLlamaSampler(const llama_model *pModel) const {
+    void llama_batch_fill_prompt(llama_batch &batch, std::span<const llama_token> input_tokens) {
+        for (auto i = 0; i < input_tokens.size(); ++i) {
+            batch.token[i] = input_tokens[i];
+            batch.pos[i] = i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id[i] = 0;
+            batch.logits[i] = false;
+            ++batch.n_tokens;
+        }
+
+        batch.logits[batch.n_tokens] = true;
+    }
+
+    std::unique_ptr<llama_sampler> sampling_params_t::into_llama_sampler(const llama_model *model) const {
         auto *pSampler = llama_sampler_chain_init({.no_perf = false});
 
         // Penalties
         llama_sampler_chain_add(pSampler, llama_sampler_init_penalties(
-                llama_n_vocab(pModel),
-                llama_token_eos(pModel),
-                llama_token_nl(pModel),
+                llama_n_vocab(model),
+                llama_token_eos(model),
+                llama_token_nl(model),
                 0.0f,
-                repetitionPenalty,
-                frequencyPenalty,
+                repetition_penalty,
+                frequency_penalty,
                 0.0f,
                 false,
                 false
         ));
-        llama_sampler_chain_add(pSampler, llama_sampler_init_top_k(static_cast<int32_t>(topK)));
+        llama_sampler_chain_add(pSampler, llama_sampler_init_top_k(static_cast<int32_t>(top_k)));
 
-        if (0 < topP && topP < 1) {
-            llama_sampler_chain_add(pSampler, llama_sampler_init_top_p(topP, 1));
+        if (0 < top_p && top_p < 1) {
+            llama_sampler_chain_add(pSampler, llama_sampler_init_top_p(top_p, 1));
         }
 
         llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
         return std::unique_ptr<llama_sampler>(pSampler);
     }
 
-    Worker::Worker(std::shared_ptr<llama_model> pModel, const llama_context_params &params)
-            : mModel_(pModel), mParams_(params) {
+    worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params)
+            : mModel_(model), mParams_(params) {
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
         char modelName[256];
-        llama_model_meta_val_str(pModel.get(), "general.name", modelName, sizeof(modelName));
+        llama_model_meta_val_str(model.get(), "general.name", modelName, sizeof(modelName));
         SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
 #endif
     }
 
-    void Worker::Loop(std::atomic_flag &running, std::atomic_uint8_t &waiting, std::queue<SamplingParams> &backlog) {
+    void worker_t::loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const {
        auto *context = llama_new_context_with_model(mModel_.get(), mParams_);
 
-        while (running.test(std::memory_order_acquire)) {
-            if (waiting.load(std::memory_order_acquire) > 0) {
-                --waiting;
+        while (!driver.stop_requested()) {
+            const auto generation_context = backlog.front();
 
-                auto request = backlog.front();
-                auto sampler = request.IntoLlamaSampler(mModel_.get());
+            generate(context, generation_context, std::nullopt);
+            backlog.pop();
 
-                // Retrieve decoding context
-                auto batch = llama_batch_get_one(tokens.data(), tokens.size());
-                // Decode
-                for (auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < 1; ++nDecoded) {
-#ifdef TGI_LLAMACPP_BACKEND_DEBUG
-                    const auto start = std::chrono::steady_clock::now();
-                    const auto status = llama_decode(context, batch);
-                    const auto end = std::chrono::steady_clock::now();
-                    const auto latency = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
-                    SPDLOG_DEBUG(FMT_STRING("Successfully decoded {:d} token(s) in {}"), batch.n_tokens, latency);
-#else
-                    const auto status = llama_decode(ctx, batch);
-#endif
-                    if (LLAMA_SUCCESS(status)) {
-                        // Sample the new token
-                        auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
-                        generated.emplace_back(new_token_id);
-                        generating = !llama_token_is_eog(mModel_.get(), new_token_id);
-
-                        // Next iteration
-                        batch = llama_batch_get_one(&new_token_id, 1);
-                    }
-                }
-
-                backlog.pop();
-
-            }
+            SPDLOG_DEBUG("Processed request ({:d} remaining)", backlog.size());
         }
 
         llama_free(context);
     }
 
-    huggingface::tgi::backends::llamacpp::BackendBase::BackendBase(llama_model *model)
-            : mModel_(model, llama_free_model) { llama_backend_init(); }
+    size_t worker_t::generate(
+            llama_context *context,
+            const generation_context_t &generation_context,
+            const std::optional<llama_decode_callback> &callback) const {
+        // Store information about context and generation size
+        auto prompt_length = std::ssize(generation_context.input_tokens);
+        auto max_new_tokens = generation_context.generation_params.max_new_tokens;
 
-    BackendBase::~BackendBase() { llama_backend_free(); }
+        // Convert sampling params to what llama.cpp is looking for
+        auto sampler = generation_context.sampling_params.into_llama_sampler(mModel_.get());
+
+        // Set up the prompt
+        auto copy = std::vector<llama_token>(generation_context.input_tokens.begin(), generation_context.input_tokens.end());
+        auto batch = llama_batch_get_one(copy.data(), copy.size());
+
+        // Decode
+        auto n_decoded_tokens = 0;
+        for (bool generating = true; generating && n_decoded_tokens < max_new_tokens; ++n_decoded_tokens) {
+            const auto callback_ = callback.value_or(llama_void_callback);
+
+#ifdef TGI_LLAMACPP_BACKEND_DEBUG
+            const auto start = std::chrono::steady_clock::now();
+            const auto status = llama_decode(context, batch);
+            const auto end = std::chrono::steady_clock::now();
+            const auto latency = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
+            SPDLOG_DEBUG(FMT_STRING("Successfully decoded {:d} token(s) in {}"), batch.n_tokens, latency);
+#else
+            const auto status = llama_decode(context, batch);
+#endif
+            batch.n_tokens = 0;
+            if (LLAMA_SUCCESS(status)) {
+                // Sample the new token
+                auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
+                auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
+
+                generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
+                generating = !is_eos;
+
+                // Bubble up the generated token if a callback is provided
+                std::invoke(std::forward<const llama_decode_callback>(callback_), new_token_id, is_eos);
+
+                batch = llama_batch_get_one(&new_token_id, 1);
+            }
+        }
+
+        return n_decoded_tokens;
+    }
+
+
+    backend_base_t::backend_base_t(llama_model *model) : mModel_(model, llama_free_model) { llama_backend_init(); }
+
+    backend_base_t::~backend_base_t() { llama_backend_free(); }
+
+    std::expected<std::vector<llama_token>, backend_error_t> backend_base_t::generate(
+            std::span<const llama_token> tokens,
+            const generation_params_t &generation_params,
+            const sampling_params_t &sampling_params,
+            const std::optional<llama_decode_callback> &callback
+    ) {
+        // TODO: Should we provide a way to change this value?
+        auto generated = std::vector<llama_token>(2 << 8);
+
+        auto nTokensGenerated = generate(tokens, generated, generation_params, sampling_params, callback);
+        if (nTokensGenerated.has_value())
+            generated.resize(*nTokensGenerated);
+        return generated;
+    }
+
+
+    /** Single worker_t backend implementation **/
+
+    single_worker_backend_t::single_worker_backend_t(llama_model *model,
+                                                     const std::optional<llama_context_params> &params)
+            : backend_base_t(model),
+              mContext_(llama_context_factory(model)),
+              mWorker_(mModel_, params.value_or(llama_context_default_params())) {
+        llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
+    }
+
+    std::expected<size_t, backend_error_t>
+    single_worker_backend_t::generate(
+            std::span<const llama_token> tokens,
+            std::span<llama_token> out,
+            const generation_params_t &generation_params,
+            const sampling_params_t &sampling_params,
+            const std::optional<llama_decode_callback> &callback
+    ) {
+        return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens, out}, callback);
+    }
 }
\ No newline at end of file
diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp
index e4814d45..e7545a3c 100644
--- a/backends/llamacpp/csrc/backend.hpp
+++ b/backends/llamacpp/csrc/backend.hpp
@@ -8,25 +8,42 @@
 #include
 #include
 #include
+#include
 #include
 #include
+#include
 #include
 #include
+#include
 #include
 #include
+#include
 
 #define LLAMA_SUCCESS(x) x == 0
 
 namespace huggingface::tgi::backends::llamacpp {
-    enum BackendError : uint8_t {
+
+    static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
+    typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr;
+
+    typedef std::function<void(llama_token, bool)> llama_decode_callback;
+    static constexpr auto llama_void_callback = [](llama_token token_id, bool is_eos) {};
+
+    /**
+     * Error codes surfaced by the backend.
+     */
+    enum backend_error_t : uint8_t {
         MODEL_FILE_DOESNT_EXIST = 1
     };
 
-    struct SamplingParams {
-        uint32_t topK = std::numeric_limits<uint32_t>::max();
-        float_t topP = 1.0f;
-        float_t frequencyPenalty = 0.0f;
-        float_t repetitionPenalty = 0.0f;
+    /**
+     * Sampling parameters forwarded to the llama.cpp sampler chain.
+     */
+    struct sampling_params_t {
+        uint32_t top_k = std::numeric_limits<uint32_t>::max();
+        float_t top_p = 1.0f;
+        float_t frequency_penalty = 0.0f;
+        float_t repetition_penalty = 0.0f;
         uint64_t seed = 2014;
 
         /**
@@ -34,38 +51,72 @@ namespace huggingface::tgi::backends::llamacpp {
          * @param Pointer to the model data
          * @return
         */
-        std::unique_ptr<llama_sampler> IntoLlamaSampler(const llama_model *) const;
+        std::unique_ptr<llama_sampler> into_llama_sampler(const llama_model *pModel) const;
     };
 
-    class Worker {
+    /**
+     * Generation parameters controlling how many new tokens are produced.
+     */
+    struct generation_params_t {
+        uint32_t max_new_tokens = std::numeric_limits<uint32_t>::max();
+    };
+
+    struct generation_context_t {
+        generation_params_t generation_params;
+        sampling_params_t sampling_params;
+        std::span<const llama_token> input_tokens;
+        std::span<llama_token> generated_tokens;
+    };
+
+    /**
+     * Owns a llama_context and runs the decoding loop for queued requests.
+     */
+    class worker_t {
+    private:
+        const std::shared_ptr<llama_model> mModel_;
+        const llama_context_params mParams_;
+
+    public:
+        /**
+         * Build a worker bound to the provided model and context parameters.
+         * @param model
+         * @param params
+         */
+        worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params);
+
+        /**
+         * Decode the prompt and sample new tokens for a single request.
+         * @param context
+         * @param generation_context
+         * @param callback
+         */
+        size_t
+        generate(llama_context *, const generation_context_t &, const std::optional<llama_decode_callback> &) const;
+
+        /**
+         * Drain the backlog until the driver requests a stop.
+         */
+        void loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const;
+    };
+
+
+    class backend_base_t {
+    protected:
-        constexpr static auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
-
-    public:
-        using model_ptr_type = std::shared_ptr<llama_model>;
-        using context_params_type = llama_context_params;
-        using token_id_type = llama_token;
-
-    private:
-        const model_ptr_type mModel_;
-        context_params_type mParams_;
-
-    public:
-        Worker(std::shared_ptr<llama_model> pModel, const llama_context_params &params);
-
-        void Loop(std::atomic_flag &, std::atomic_uint8_t &, std::queue<SamplingParams> &) const;
-    };
-
-
-    class BackendBase {
-
-    private:
         std::shared_ptr<llama_model> mModel_;
 
     public:
-        explicit BackendBase(llama_model *model);
-        ~BackendBase();
+        /**
+         * Take ownership of the model and initialise the llama.cpp backend.
+         * @param model
+         */
+        explicit backend_base_t(llama_model *model);
+
+        /**
+         * Destructor
+         */
+        ~backend_base_t();
 
         /**
          *
@@ -76,12 +127,13 @@ namespace huggingface::tgi::backends::llamacpp {
          * @return
         */
        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-        std::expected<std::vector<llama_token>, BackendError> Generate(
-                std::span<const llama_token> tokens,
-                std::span<llama_token> out,
-                const SamplingParams &params,
-                uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max() - 1
-        );
+        virtual std::expected<size_t, backend_error_t> generate(
+                std::span<const llama_token> input_tokens,
+                std::span<llama_token> generated_tokens,
+                const generation_params_t &generation_params,
+                const sampling_params_t &sampling_params,
+                const std::optional<llama_decode_callback> &callback
+        ) = 0;
 
         /**
          *
@@ -91,12 +143,46 @@ namespace huggingface::tgi::backends::llamacpp {
          * @return
         */
        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-        std::expected<std::vector<llama_token>, BackendError> Generate(
+        std::expected<std::vector<llama_token>, backend_error_t> generate(
                 std::span<const llama_token> tokens,
-                const SamplingParams &params,
-                uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max() - 1
+                const generation_params_t &generation_params,
+                const sampling_params_t &sampling_params,
+                const std::optional<llama_decode_callback> &callback = std::nullopt
         );
     };
+
+
+    class single_worker_backend_t : backend_base_t {
+    private:
+        constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_smart_ptr {
+            auto llParams = llama_context_default_params();
+            llParams.flash_attn = true;
+            llParams.n_batch = 1;
+            llParams.no_perf = true;
+            llParams.attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL;
+
+            return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
+        };
+
+        llama_context_smart_ptr mContext_;
+        worker_t mWorker_;
+
+    public:
+        explicit single_worker_backend_t(llama_model *pModel, const std::optional<llama_context_params> &);
+
+        using backend_base_t::generate;
+
+        std::expected<size_t, backend_error_t>
+        generate(
+                std::span<const llama_token> tokens,
+                std::span<llama_token> out,
+                const generation_params_t &generation_params,
+                const sampling_params_t &sampling_params,
+                const std::optional<llama_decode_callback> &callback
+        ) override;
+
+
+    };
 }
 
 #endif //TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
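
Note (not part of the diff): a minimal caller-side sketch of how the new API might be driven, to give reviewers context. The function run_request, the pre-tokenized prompt_tokens vector, and the on_token callback are hypothetical names introduced here for illustration only; it assumes the llama_model has already been loaded elsewhere (e.g. via llama_load_model_from_file) and the prompt already tokenized.

    #include <optional>
    #include <vector>
    #include "backend.hpp"

    using namespace huggingface::tgi::backends::llamacpp;

    // Hypothetical driver: generate up to 32 new tokens from an already-tokenized prompt.
    std::vector<llama_token> run_request(llama_model *model, const std::vector<llama_token> &prompt_tokens) {
        // std::nullopt -> the worker falls back to llama_context_default_params()
        single_worker_backend_t backend(model, std::nullopt);

        const generation_params_t generation_params{.max_new_tokens = 32};
        const sampling_params_t sampling_params{};  // header defaults: top_k/top_p effectively disabled, seed = 2014

        // Tokens are bubbled up through the llama_decode_callback as they are sampled.
        const llama_decode_callback on_token = [](llama_token token_id, bool is_eos) {
            // e.g. forward token_id to the streaming layer; is_eos marks the end of generation
        };

        // Uses the base-class overload exposed via `using backend_base_t::generate;`,
        // which allocates the output buffer and returns the generated tokens.
        const auto generated = backend.generate(prompt_tokens, generation_params, sampling_params, on_token);
        return generated.value_or(std::vector<llama_token>{});
    }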