feat(llamacpp): wip explosion

Morgan Funtowicz 2024-10-29 22:30:36 +01:00
parent a316c53255
commit 0c1dd0ed2b
3 changed files with 120 additions and 159 deletions

View File

@@ -15,82 +15,15 @@
 #include "backend.hpp"
 
 namespace huggingface::tgi::backends::llamacpp {
-    [[nodiscard]]
-    std::expected<std::pair<llama_model *, llama_context *>, TgiLlamaCppBackendError>
-    TgiLlamaCppBackend::FromGGUF(const std::filesystem::path &modelPath, const uint16_t nThreads) noexcept {
-        SPDLOG_DEBUG(FMT_STRING("Loading model from {}"), modelPath);
-        llama_backend_init();
-        llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
-#ifdef TGI_LLAMACPP_BACKEND_DEBUG
-        llama_print_system_info();
-#endif
-
-        // Load the model
-        if (!exists(modelPath)) {
-            return std::unexpected(TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST);
-        }
-
-        auto params = llama_model_default_params();
-        auto *model = llama_load_model_from_file(modelPath.c_str(), params);
-        auto *context = llama_new_context_with_model(model, {
-                .n_batch = 1,
-                .n_threads = nThreads,
-                .attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL,
-                .flash_attn = false,
-        });
-
-        return std::make_pair(model, context);
-    }
-
-    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model,
-                                                                                 llama_context *const ctx)
-            : model(model), ctx(ctx) {
-#ifdef TGI_LLAMACPP_BACKEND_DEBUG
-        char modelName[256];
-        llama_model_meta_val_str(llama_get_model(ctx), "general.name", modelName, sizeof(modelName));
-        SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
-#endif
-    }
-
-    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
-        if (ctx) {
-            SPDLOG_DEBUG("Freeing llama.cpp context");
-            llama_free(ctx);
-        }
-
-        if (model) {
-            SPDLOG_DEBUG("Freeing llama.cpp model");
-            llama_free_model(model);
-        }
-    }
-
-    std::vector<TgiLlamaCppBackend::TokenId> TgiLlamaCppBackend::Tokenize(const std::string &text) const {
-        std::vector<TgiLlamaCppBackend::TokenId> tokens(llama_n_seq_max(ctx));
-
-        if (auto nTokens = llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true,
-                                          true); nTokens < 0) {
-            tokens.resize(-nTokens);
-            llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true, true);
-        } else {
-            tokens.resize(nTokens);
-        }
-
-        SPDLOG_DEBUG(FMT_STRING("Tokenized input with {:d} tokens"), tokens.size());
-        return tokens;
-    }
-
-    std::unique_ptr<llama_sampler *> TgiLlamaCppBackend::GetSamplerFromArgs(
-            const uint32_t topK, const float_t topP, const float_t frequencyPenalty, const float_t repetitionPenalty,
-            const uint64_t seed) {
-        auto *sampler = llama_sampler_chain_init({.no_perf = false});
+    std::unique_ptr<llama_sampler> SamplingParams::IntoLlamaSampler(const llama_model *pModel) const {
+        auto *pSampler = llama_sampler_chain_init({.no_perf = false});
 
         // Penalties
-        llama_sampler_chain_add(sampler, llama_sampler_init_penalties(
-                llama_n_vocab(model),
-                llama_token_eos(model),
-                llama_token_nl(model),
+        llama_sampler_chain_add(pSampler, llama_sampler_init_penalties(
+                llama_n_vocab(pModel),
+                llama_token_eos(pModel),
+                llama_token_nl(pModel),
                 0.0f,
                 repetitionPenalty,
                 frequencyPenalty,
@@ -98,41 +31,43 @@ namespace huggingface::tgi::backends::llamacpp {
                 false,
                 false
         ));
-        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(static_cast<int32_t>(topK)));
+        llama_sampler_chain_add(pSampler, llama_sampler_init_top_k(static_cast<int32_t>(topK)));
 
         if (0 < topP && topP < 1) {
-            llama_sampler_chain_add(sampler, llama_sampler_init_top_p(topP, 1));
+            llama_sampler_chain_add(pSampler, llama_sampler_init_top_p(topP, 1));
         }
 
-        llama_sampler_chain_add(sampler, llama_sampler_init_dist(seed));
-        return std::make_unique<llama_sampler *>(sampler);
+        llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
+        return std::unique_ptr<llama_sampler>(pSampler);
     }
 
-    std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError>
-    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::Generate(
-            std::span<const TokenId> tokens,
-            const uint32_t topK,
-            const float_t topP,
-            const float_t frequencyPenalty,
-            const float_t repetitionPenalty,
-            const uint32_t maxNewTokens,
-            const uint64_t seed
-    ) {
-        SPDLOG_DEBUG(FMT_STRING("Received {:d} tokens to schedule"), tokens.size());
-
-        // Allocate generation result
-        std::vector<TgiLlamaCppBackend::TokenId> generated;
-        generated.reserve(llama_n_seq_max(ctx) - tokens.size());
+    Worker::Worker(std::shared_ptr<llama_model> pModel, const llama_context_params &params)
+            : mModel_(pModel), mParams_(params) {
+#ifdef TGI_LLAMACPP_BACKEND_DEBUG
+        char modelName[256];
+        llama_model_meta_val_str(pModel.get(), "general.name", modelName, sizeof(modelName));
+        SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
+#endif
+    }
+
+    void Worker::Loop(std::atomic_flag &running, std::atomic_uint8_t &waiting, std::queue<SamplingParams> &backlog) {
+        auto *context = llama_new_context_with_model(mModel_.get(), mParams_);
+
+        while (running.test(std::memory_order_acquire)) {
+            if (waiting.load(std::memory_order_acquire) > 0) {
+                --waiting;
+                auto request = backlog.front();
+                auto sampler = request.IntoLlamaSampler(mModel_.get());
 
         // Retrieve decoding context
-        auto batch = llama_batch_get_one(const_cast<int32_t *>(tokens.data()), static_cast<int32_t>(tokens.size()));
-        auto sampler = GetSamplerFromArgs(topK, topP, frequencyPenalty, repetitionPenalty, seed);
+        auto batch = llama_batch_get_one(tokens.data(), tokens.size());
 
         // Decode
-        for (auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
+        for (auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < 1; ++nDecoded) {
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
             const auto start = std::chrono::steady_clock::now();
-            const auto status = llama_decode(ctx, batch);
+            const auto status = llama_decode(context, batch);
             const auto end = std::chrono::steady_clock::now();
             const auto latency = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
             SPDLOG_DEBUG(FMT_STRING("Successfully decoded {:d} token(s) in {}"), batch.n_tokens, latency);
@@ -141,14 +76,25 @@ namespace huggingface::tgi::backends::llamacpp {
 #endif
             if (LLAMA_SUCCESS(status)) {
                 // Sample the new token
-                auto new_token_id = llama_sampler_sample(*sampler, ctx, -1);
+                auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
                 generated.emplace_back(new_token_id);
-                generating = !llama_token_is_eog(model, new_token_id);
+                generating = !llama_token_is_eog(mModel_.get(), new_token_id);
 
                 // Next iteration
                 batch = llama_batch_get_one(&new_token_id, 1);
             }
         }
-
-        return generated;
+                backlog.pop();
     }
+        }
+
+        llama_free(context);
+    }
+
+    huggingface::tgi::backends::llamacpp::BackendBase::BackendBase(llama_model *model)
+            : mModel_(model, llama_free_model) { llama_backend_init(); }
+
+    BackendBase::~BackendBase() { llama_backend_free(); }
 }
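
Note: the Worker/SamplingParams split above implies a driver that owns the synchronization state polled by Worker::Loop. A minimal sketch of such a driver follows; it is not part of this commit, the RunWorker name and the single-producer shutdown handshake are assumptions, and only the Worker and SamplingParams signatures introduced above are taken as given.

    // Hypothetical driver for Worker::Loop -- not part of this commit.
    #include <atomic>
    #include <memory>
    #include <queue>
    #include <thread>

    #include "backend.hpp"

    using namespace huggingface::tgi::backends::llamacpp;

    void RunWorker(const std::shared_ptr<llama_model> &model, const llama_context_params &params) {
        std::atomic_flag running;            // clear on construction (C++20)
        std::atomic_uint8_t waiting{0};      // number of queued requests
        std::queue<SamplingParams> backlog;  // requests polled by the worker

        // Queue one request before the loop starts; a real scheduler would guard
        // the queue and the counter with a mutex or a lock-free channel.
        backlog.push(SamplingParams{.topK = 40, .topP = 0.95f, .seed = 2014});
        waiting.fetch_add(1, std::memory_order_release);

        running.test_and_set(std::memory_order_release);
        Worker worker(model, params);
        std::thread loop([&] { worker.Loop(running, waiting, backlog); });

        // ... submit more work here, then request shutdown ...
        running.clear(std::memory_order_release);
        loop.join();
    }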

View File

@@ -4,9 +4,11 @@
 #ifndef TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
 #define TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
 
+#include <atomic>
 #include <cmath>
 #include <expected>
 #include <filesystem>
+#include <queue>
 #include <memory>
 #include <span>
 #include <vector>
@@ -16,72 +18,85 @@
 #define LLAMA_SUCCESS(x) x == 0
 
 namespace huggingface::tgi::backends::llamacpp {
-    enum TgiLlamaCppBackendError : uint8_t {
+    enum BackendError : uint8_t {
         MODEL_FILE_DOESNT_EXIST = 1
     };
 
-    class TgiLlamaCppBackend {
-        using TokenId = llama_token;
-
-    private:
-        llama_model *model;
-        llama_context *ctx;
+    struct SamplingParams {
+        uint32_t topK = std::numeric_limits<decltype(topK)>::max();
+        float_t topP = 1.0f;
+        float_t frequencyPenalty = 0.0f;
+        float_t repetitionPenalty = 0.0f;
+        uint64_t seed = 2014;
 
         /**
-         *
-         * @param topK
-         * @param topP
+         * Convert this GenerationParams to the respective llama_sampler structure
+         * @param Pointer to the model data
          * @return
         */
-        std::unique_ptr<llama_sampler *> GetSamplerFromArgs(
-                uint32_t topK, float_t topP, float_t frequencyPenalty, float_t repetitionPenalty, uint64_t seed);
+        std::unique_ptr<llama_sampler> IntoLlamaSampler(const llama_model *) const;
+    };
+
+    class Worker {
+    protected:
+        constexpr static auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
 
     public:
-        /**
-         *
-         * @return
-         */
-        static std::expected<std::pair<llama_model *, llama_context *>, TgiLlamaCppBackendError>
-        FromGGUF(const std::filesystem::path &, uint16_t) noexcept;
-
-        TgiLlamaCppBackend(llama_model *model, llama_context *ctx);
-
-        ~TgiLlamaCppBackend();
-
-        /**
-         *
-         * @param text
-         * @return
-         */
-        [[nodiscard("Tokens will be freed after this call if not assigned to an lvalue")]]
-        std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string &text) const;
+        using model_ptr_type = std::shared_ptr<llama_model>;
+        using context_params_type = llama_context_params;
+        using token_id_type = llama_token;
+
+    private:
+        const model_ptr_type mModel_;
+        context_params_type mParams_;
+
+    public:
+        Worker(std::shared_ptr<llama_model> pModel, const llama_context_params &params);
+
+        void Loop(std::atomic_flag &, std::atomic_uint8_t &, std::queue<SamplingParams> &) const;
+    };
+
+    class BackendBase {
+    private:
+        std::shared_ptr<llama_model> mModel_;
+
+    public:
+        explicit BackendBase(llama_model *model);
+
+        ~BackendBase();
 
         /**
          *
         * @param tokens
-         * @param topK
-         * @param topP
-         * @param frequencyPenalty
-         * @param repetitionPenalty
+         * @params out
+         * @param params
         * @param maxNewTokens
-         * @param seed
         * @return
        */
        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-        std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError> Generate(
-                std::span<const TokenId> tokens,
-                uint32_t topK,
-                float_t topP = 1.0f,
-                float_t frequencyPenalty = 0.0f,
-                float_t repetitionPenalty = 0.0f,
-                uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max() - 1,
-                uint64_t seed = 2014
+        std::expected<std::vector<llama_token>, BackendError> Generate(
+                std::span<const llama_token> tokens,
+                std::span<llama_token> out,
+                const SamplingParams &params,
+                uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max() - 1
+        );
+
+        /**
+         *
+         * @param tokens
+         * @param params
+         * @param maxNewTokens
+         * @return
+         */
+        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
+        std::expected<std::vector<llama_token>, BackendError> Generate(
+                std::span<const llama_token> tokens,
+                const SamplingParams &params,
+                uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max() - 1
         );
     };
-
-    [[nodiscard("Create backend will be freed after this call if not assigned to an lvalue")]]
-    std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
-    CreateLlamaCppBackend(const std::filesystem::path &root);
 }
 
 #endif //TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
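
Note: for context, a caller of the second Generate overload declared above might look roughly like this. This is a sketch only: GenerateOnce is a hypothetical helper, the backend and the pre-tokenized prompt are assumed to come from elsewhere, and it relies on the std::expected/BackendError contract of the header rather than on any definition shown in this commit.

    // Hypothetical caller of BackendBase::Generate -- not part of this commit.
    #include <span>
    #include <utility>
    #include <vector>

    #include "backend.hpp"

    using namespace huggingface::tgi::backends::llamacpp;

    std::vector<llama_token> GenerateOnce(BackendBase &backend, std::span<const llama_token> prompt) {
        // Sampling settings now travel as a single struct instead of a long argument list.
        const SamplingParams params{.topK = 40, .topP = 0.95f, .repetitionPenalty = 1.1f, .seed = 2014};

        auto generated = backend.Generate(prompt, params, /* maxNewTokens = */ 128);
        if (!generated.has_value()) {
            // BackendError currently only carries MODEL_FILE_DOESNT_EXIST.
            return {};
        }
        return std::move(generated).value();
    }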

View File

@@ -28,10 +28,10 @@ namespace huggingface::tgi::backends::llamacpp::impl {
     class LlamaCppBackendImpl {
     private:
-        TgiLlamaCppBackend _inner;
+        BackendBase _inner;
 
     public:
-        LlamaCppBackendImpl(llama_model *model, llama_context *context) : _inner(model, context) {}
+        LlamaCppBackendImpl(llama_model *model) : _inner(model) {}
     };
 
     std::unique_ptr<LlamaCppBackendImpl> CreateLlamaCppBackendImpl(rust::Str modelPath, uint16_t nThreads) {