mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 03:52:08 +00:00
feat(backend): add some initial decoding steps
This commit is contained in:
parent
098c66920d
commit
45d5a6a8c5
@ -2,20 +2,23 @@
|
|||||||
// Created by Morgan Funtowicz on 9/28/2024.
|
// Created by Morgan Funtowicz on 9/28/2024.
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <expected>
|
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
|
#include <span>
|
||||||
|
|
||||||
#include <ggml.h>
|
#include <ggml.h>
|
||||||
#include <llama.h>
|
#include <llama.h>
|
||||||
|
#include <fmt/chrono.h>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <fmt/std.h>
|
#include <fmt/std.h>
|
||||||
#include <spdlog/spdlog.h>
|
#include <spdlog/spdlog.h>
|
||||||
|
|
||||||
#include "backend.hpp"
|
#include "backend.hpp"
|
||||||
|
|
||||||
namespace huggingface::tgi::backends::llama {
|
namespace huggingface::tgi::backends::llama {
|
||||||
|
|
||||||
std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
|
std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
|
||||||
CreateLlamaCppBackend(const std::filesystem::path& modelPath) {
|
CreateLlamaCppBackend(const std::filesystem::path& modelPath) {
|
||||||
SPDLOG_INFO(FMT_STRING("Loading model from {}"), modelPath);
|
SPDLOG_DEBUG(FMT_STRING("Loading model from {}"), modelPath);
|
||||||
llama_backend_init();
|
llama_backend_init();
|
||||||
llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
|
llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
|
||||||
|
|
||||||
@ -28,39 +31,109 @@ namespace huggingface::tgi::backends::llama {
|
|||||||
auto* model = llama_load_model_from_file(modelPath.c_str(), params);
|
auto* model = llama_load_model_from_file(modelPath.c_str(), params);
|
||||||
auto* context = llama_new_context_with_model(model, {
|
auto* context = llama_new_context_with_model(model, {
|
||||||
.n_batch = 1,
|
.n_batch = 1,
|
||||||
|
.n_threads = 16,
|
||||||
.attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL,
|
.attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL,
|
||||||
.flash_attn = true,
|
.flash_attn = false,
|
||||||
});
|
});
|
||||||
|
|
||||||
return std::make_unique<huggingface::tgi::backends::llama::TgiLlamaCppBackend>(model, context);
|
return std::make_unique<huggingface::tgi::backends::llama::TgiLlamaCppBackend>(model, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
huggingface::tgi::backends::llama::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model, llama_context *const ctx)
|
huggingface::tgi::backends::llama::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model, llama_context *const ctx)
|
||||||
: model(model), ctx(ctx), batch() {
|
: model(model), ctx(ctx) {
|
||||||
char modelName[128];
|
#ifndef NDEBUG
|
||||||
llama_model_meta_val_str(model, "general.name", modelName, sizeof(modelName));
|
char modelName[256];
|
||||||
|
llama_model_meta_val_str(llama_get_model(ctx), "general.name", modelName, sizeof(modelName));
|
||||||
SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
|
SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
huggingface::tgi::backends::llama::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
|
huggingface::tgi::backends::llama::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
|
||||||
if (model) {
|
|
||||||
SPDLOG_DEBUG("Freeing llama.cpp model");
|
|
||||||
llama_free_model(model);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ctx) {
|
if (ctx) {
|
||||||
SPDLOG_DEBUG("Freeing llama.cpp context");
|
SPDLOG_DEBUG("Freeing llama.cpp context");
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(model) {
|
||||||
|
SPDLOG_DEBUG("Freeing llama.cpp model");
|
||||||
|
llama_free_model(model);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void huggingface::tgi::backends::llama::TgiLlamaCppBackend::schedule() {
|
std::vector<TgiLlamaCppBackend::TokenId> TgiLlamaCppBackend::Tokenize(const std::string &text) const {
|
||||||
std::vector<llama_token> tokens;
|
std::vector<TgiLlamaCppBackend::TokenId> tokens(llama_n_seq_max(ctx));
|
||||||
|
|
||||||
|
if(auto nTokens = llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true, true); nTokens < 0){
|
||||||
|
tokens.resize(-nTokens);
|
||||||
|
llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true, true);
|
||||||
|
} else {
|
||||||
|
tokens.resize(nTokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace impl {
|
SPDLOG_DEBUG(FMT_STRING("Tokenized input with {:d} tokens"), tokens.size());
|
||||||
class LlamaCppBackendImpl {
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
};
|
std::unique_ptr<llama_sampler *> TgiLlamaCppBackend::GetSamplerFromArgs(
|
||||||
|
const uint32_t topK, const float_t topP, const float_t frequencyPenalty, const float_t repetitionPenalty, const uint64_t seed) {
|
||||||
|
auto *sampler = llama_sampler_chain_init({.no_perf = false});
|
||||||
|
|
||||||
|
// Penalties
|
||||||
|
llama_sampler_chain_add(sampler, llama_sampler_init_penalties(
|
||||||
|
llama_n_vocab(model),
|
||||||
|
llama_token_eos(model),
|
||||||
|
llama_token_nl (model),
|
||||||
|
0.0f,
|
||||||
|
repetitionPenalty,
|
||||||
|
frequencyPenalty,
|
||||||
|
0.0f,
|
||||||
|
false,
|
||||||
|
false
|
||||||
|
));
|
||||||
|
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(static_cast<int32_t>(topK)));
|
||||||
|
|
||||||
|
if(0 < topP && topP < 1) {
|
||||||
|
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(topP, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sampler_chain_add(sampler, llama_sampler_init_dist(seed));
|
||||||
|
return std::make_unique<llama_sampler*>(sampler);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<TgiLlamaCppBackend::TokenId> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
|
||||||
|
std::span<const TokenId> tokens, const uint32_t topK, const float_t topP, const uint32_t maxNewTokens) {
|
||||||
|
SPDLOG_DEBUG(FMT_STRING("Received {:d} tokens to schedule"), tokens.size());
|
||||||
|
|
||||||
|
// Allocate generation result
|
||||||
|
std::vector<TgiLlamaCppBackend::TokenId> generated;
|
||||||
|
generated.reserve(llama_n_seq_max(ctx) - tokens.size());
|
||||||
|
|
||||||
|
// Retrieve decoding context
|
||||||
|
auto batch = llama_batch_get_one(const_cast<int32_t *>(tokens.data()), static_cast<int32_t>(tokens.size()));
|
||||||
|
auto sampler = GetSamplerFromArgs(topK, topP, 1.0, 1.0, 2014);
|
||||||
|
|
||||||
|
// Decode
|
||||||
|
for(auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
|
||||||
|
#ifndef NDEBUG
|
||||||
|
const auto start = std::chrono::steady_clock::now();
|
||||||
|
const auto status = llama_decode(ctx, batch);
|
||||||
|
const auto end = std::chrono::steady_clock::now();
|
||||||
|
const auto latency = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
|
||||||
|
SPDLOG_DEBUG(FMT_STRING("Successfully decoded {:d} token(s) in {}"), batch.n_tokens, latency);
|
||||||
|
#else
|
||||||
|
const auto status = llama_decode(ctx, batch);
|
||||||
|
#endif
|
||||||
|
if (status == LLAMA_SUCCESS) {
|
||||||
|
// Sample the new token
|
||||||
|
auto new_token_id = llama_sampler_sample(*sampler, ctx, -1);
|
||||||
|
generated.emplace_back(new_token_id);
|
||||||
|
generating = !llama_token_is_eog(model, new_token_id);
|
||||||
|
|
||||||
|
// Next iteration
|
||||||
|
batch = llama_batch_get_one(&new_token_id, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
generated.shrink_to_fit();
|
||||||
|
return generated;
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -4,28 +4,61 @@
|
|||||||
#ifndef TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
|
#ifndef TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
|
||||||
#define TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
|
#define TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <expected>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <llama.h>
|
#include <llama.h>
|
||||||
|
|
||||||
namespace huggingface::tgi::backends::llama {
|
#define LLAMA_SUCCESS 0
|
||||||
// const char* TGI_BACKEND_LLAMA_CPP_NAME = "llama.cpp";
|
|
||||||
|
|
||||||
|
namespace huggingface::tgi::backends::llama {
|
||||||
enum TgiLlamaCppBackendError {
|
enum TgiLlamaCppBackendError {
|
||||||
MODEL_FILE_DOESNT_EXIST = 1
|
MODEL_FILE_DOESNT_EXIST = 1
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class TgiLlamaCppBackend {
|
class TgiLlamaCppBackend {
|
||||||
|
using TokenId = int32_t;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llama_model* model;
|
llama_model* model;
|
||||||
llama_context* ctx;
|
llama_context* ctx;
|
||||||
llama_batch batch;
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param topK
|
||||||
|
* @param topP
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
std::unique_ptr<llama_sampler *> GetSamplerFromArgs(
|
||||||
|
uint32_t topK, float_t topP, float_t frequencyPenalty, float_t repetitionPenalty, uint64_t seed);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TgiLlamaCppBackend(llama_model *model, llama_context *ctx);
|
TgiLlamaCppBackend(llama_model *model, llama_context *ctx);
|
||||||
~TgiLlamaCppBackend();
|
~TgiLlamaCppBackend();
|
||||||
|
|
||||||
void schedule();
|
/**
|
||||||
|
*
|
||||||
|
* @param text
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
[[nodiscard]] std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string& text) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param tokens
|
||||||
|
* @param topK
|
||||||
|
* @param topP
|
||||||
|
* @param maxNewTokens
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
[[nodiscard]] std::vector<TgiLlamaCppBackend::TokenId> Generate(
|
||||||
|
std::span<const TokenId> tokens,
|
||||||
|
uint32_t topK,
|
||||||
|
float_t topP = 1.0f,
|
||||||
|
uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max()
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
|
std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
|
||||||
|
@ -3,21 +3,37 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
|
#include <fmt/color.h>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <fmt/std.h>
|
#include <fmt/std.h>
|
||||||
#include <fmt/color.h>
|
#include <fmt/ranges.h>
|
||||||
#include <spdlog/spdlog.h>
|
#include <spdlog/spdlog.h>
|
||||||
#include "../csrc/backend.hpp"
|
#include "../csrc/backend.hpp"
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
if(argc < 2) {
|
if (argc < 2) {
|
||||||
fmt::print("No model folder provider");
|
fmt::print("No model folder provider");
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
spdlog::set_level(spdlog::level::debug);
|
spdlog::set_level(spdlog::level::debug);
|
||||||
|
|
||||||
|
const auto prompt = "My name is Morgan";
|
||||||
|
|
||||||
const auto modelPath = absolute(std::filesystem::path(argv[1]));
|
const auto modelPath = absolute(std::filesystem::path(argv[1]));
|
||||||
if(auto backend = huggingface::tgi::backends::llama::CreateLlamaCppBackend(modelPath); backend.has_value())
|
if (auto maybeBackend = huggingface::tgi::backends::llama::CreateLlamaCppBackend(modelPath); maybeBackend.has_value()) {
|
||||||
fmt::print(fmt::emphasis::bold | fg(fmt::color::yellow), "Successfully initialized llama.cpp model from {}\n", modelPath);
|
// Retrieve the backend
|
||||||
|
const auto& backend = *maybeBackend;
|
||||||
|
|
||||||
|
// Generate
|
||||||
|
const auto promptTokens = backend->Tokenize(prompt);
|
||||||
|
const auto out = backend->Generate(promptTokens, 30, 1.0, 32);
|
||||||
|
fmt::print(FMT_STRING("Generated: {}"), out);
|
||||||
|
} else {
|
||||||
|
switch (maybeBackend.error()) {
|
||||||
|
case huggingface::tgi::backends::llama::TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST:
|
||||||
|
fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Specified file {} doesnt exist", modelPath);
|
||||||
|
return maybeBackend.error();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user