feat(backend): build and link through build.rs

Author: Morgan Funtowicz
Date:   2024-10-24 16:42:50 +02:00
Parent: 355d8a55b4
Commit: e4d803c94e

11 changed files with 296 additions and 114 deletions

Cargo.lock (generated)

@@ -2732,6 +2732,20 @@ dependencies = [
  "thiserror",
 ]
 
+[[package]]
+name = "opentelemetry"
+version = "0.26.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "570074cc999d1a58184080966e5bd3bf3a9a4af650c3b05047c2621e7405cd17"
+dependencies = [
+ "futures-core",
+ "futures-sink",
+ "js-sys",
+ "once_cell",
+ "pin-project-lite",
+ "thiserror",
+]
+
 [[package]]
 name = "opentelemetry-otlp"
 version = "0.13.0"
@@ -2849,6 +2863,24 @@ dependencies = [
  "thiserror",
 ]
 
+[[package]]
+name = "opentelemetry_sdk"
+version = "0.26.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d2c627d9f4c9cdc1f21a29ee4bfbd6028fcb8bcf2a857b43f3abdf72c9c862f3"
+dependencies = [
+ "async-trait",
+ "futures-channel",
+ "futures-executor",
+ "futures-util",
+ "glob",
+ "once_cell",
+ "opentelemetry 0.26.0",
+ "percent-encoding",
+ "rand",
+ "thiserror",
+]
+
 [[package]]
 name = "option-ext"
 version = "0.2.0"
@@ -4187,12 +4219,14 @@ dependencies = [
 name = "text-generation-backend-llamacpp"
 version = "2.4.1-dev0"
 dependencies = [
+ "async-trait",
  "clap 4.5.20",
  "cmake",
  "cxx",
  "cxx-build",
  "hf-hub",
  "image",
+ "log",
  "metrics",
  "metrics-exporter-prometheus",
  "pkg-config",
@@ -4202,6 +4236,10 @@ dependencies = [
  "tokenizers",
  "tokio",
  "tokio-stream",
+ "tracing",
+ "tracing-opentelemetry 0.27.0",
+ "tracing-subscriber",
+ "utoipa 5.1.2",
 ]
 
 [[package]]
@@ -4330,7 +4368,7 @@ dependencies = [
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
  "ureq",
- "utoipa",
+ "utoipa 4.2.3",
  "utoipa-swagger-ui",
  "uuid",
  "vergen",
@@ -4381,7 +4419,7 @@ dependencies = [
  "tracing",
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
- "utoipa",
+ "utoipa 4.2.3",
  "utoipa-swagger-ui",
 ]
@@ -4432,7 +4470,7 @@ dependencies = [
  "tracing",
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
- "utoipa",
+ "utoipa 4.2.3",
  "utoipa-swagger-ui",
 ]
@@ -4946,6 +4984,24 @@ dependencies = [
  "web-time 1.1.0",
 ]
 
+[[package]]
+name = "tracing-opentelemetry"
+version = "0.27.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc58af5d3f6c5811462cabb3289aec0093f7338e367e5a33d28c0433b3c7360b"
+dependencies = [
+ "js-sys",
+ "once_cell",
+ "opentelemetry 0.26.0",
+ "opentelemetry_sdk 0.26.0",
+ "smallvec",
+ "tracing",
+ "tracing-core",
+ "tracing-log 0.2.0",
+ "tracing-subscriber",
+ "web-time 1.1.0",
+]
+
 [[package]]
 name = "tracing-opentelemetry-instrumentation-sdk"
 version = "0.16.0"
@@ -5136,7 +5192,19 @@ dependencies = [
  "indexmap 2.6.0",
  "serde",
  "serde_json",
- "utoipa-gen",
+ "utoipa-gen 4.3.0",
+]
+
+[[package]]
+name = "utoipa"
+version = "5.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5e12e84f0ff45b6818029cd0f67280e453c80132c1b9897df407ecc20b9f7cfd"
+dependencies = [
+ "indexmap 2.5.0",
+ "serde",
+ "serde_json",
+ "utoipa-gen 5.1.2",
 ]
 
 [[package]]
@@ -5152,6 +5220,18 @@ dependencies = [
  "syn 2.0.85",
 ]
 
+[[package]]
+name = "utoipa-gen"
+version = "5.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0dfc694d3a3118d2b9e80d68be83bf1aab7988510916934db83da61c14e7e6b2"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "regex",
+ "syn 2.0.79",
+]
+
 [[package]]
 name = "utoipa-swagger-ui"
 version = "6.0.0"
@@ -5164,7 +5244,7 @@ dependencies = [
  "rust-embed",
  "serde",
  "serde_json",
- "utoipa",
+ "utoipa 4.2.3",
  "zip",
 ]

backends/llamacpp/CMakeLists.txt

@@ -11,12 +11,12 @@ set(LLAMA_CPP_TARGET_CUDA_ARCHS "75-real;80-real;86-real;89-real;90-real" CACHE
 option(LLAMA_CPP_BUILD_OFFLINE_RUNNER "Flag to build the standalone c++ backend runner")
 option(LLAMA_CPP_BUILD_CUDA "Flag to build CUDA enabled inference through llama.cpp")
 
-if(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" AND ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
+if (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" AND ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
     message(STATUS "Targeting libc++")
     set(CMAKE_CXX_FLAGS -stdlib=libc++ ${CMAKE_CXX_FLAGS})
-else()
+else ()
     message(STATUS "Not using libc++ ${CMAKE_CXX_COMPILER_ID} ${CMAKE_SYSTEM_NAME}")
-endif()
+endif ()
 
 # Add dependencies
 include(cmake/fmt.cmake)
@@ -42,15 +42,17 @@ fetchcontent_declare(
 fetchcontent_makeavailable(llama)
 
-add_library(tgi_llama_cpp_backend_impl STATIC csrc/backend.hpp csrc/backend.cpp)
-target_compile_features(tgi_llama_cpp_backend_impl PRIVATE cxx_std_11)
-target_link_libraries(tgi_llama_cpp_backend_impl PUBLIC fmt::fmt spdlog::spdlog llama common)
+add_library(tgi_llamacpp_backend_impl STATIC csrc/backend.hpp csrc/backend.cpp)
+target_compile_features(tgi_llamacpp_backend_impl PRIVATE cxx_std_11)
+target_link_libraries(tgi_llamacpp_backend_impl PUBLIC fmt::fmt spdlog::spdlog llama common)
+
+install(TARGETS tgi_llamacpp_backend_impl spdlog llama common)
 
 if (${LLAMA_CPP_BUILD_OFFLINE_RUNNER})
     message(STATUS "Building llama.cpp offline runner")
-    add_executable(tgi_llama_cpp_offline_runner offline/main.cpp)
-    target_link_libraries(tgi_llama_cpp_offline_runner PUBLIC tgi_llama_cpp_backend_impl llama common)
+    add_executable(tgi_llama_cppoffline_runner offline/main.cpp)
+    target_link_libraries(tgi_llamacpp_offline_runner PUBLIC tgi_llama_cpp_backend_impl llama common)
 endif ()

backends/llamacpp/Cargo.toml

@@ -6,6 +6,7 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
+async-trait = "0.1"
 clap = { version = "4.5.19", features = ["derive"] }
 cxx = "1.0"
 hf-hub = { workspace = true }
@@ -18,6 +19,11 @@ thiserror = "1.0.64"
 tokio = "1.40.0"
 tokio-stream = "0.1.16"
 tokenizers = { workspace = true }
+tracing = "0.1"
+tracing-opentelemetry = "0.27.0"
+tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+utoipa = { version = "5.1.2", features = ["axum_extras"] }
+log = "0.4.22"
 
 [build-dependencies]
 cmake = "0.1"

backends/llamacpp/build.rs

@@ -1,12 +1,14 @@
 use cxx_build::CFG;
 use std::env;
-use std::path::PathBuf;
+use std::path::{Path, PathBuf};
 
 const CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS: &str = "75-real;80-real;86-real;89-real;90-real";
-const CMAKE_LLAMA_CPP_TARGET: &str = "tgi_llama_cpp_backend_impl";
-const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
+const CMAKE_LLAMA_CPP_TARGET: &str = "tgi_llamacpp_backend_impl";
+const CMAKE_LLAMA_CPP_FFI_TARGET: &str = "tgi_llamacpp_backend";
 const MPI_REQUIRED_VERSION: &str = "4.1";
+const BACKEND_DEPS: [&str; 2] = [CMAKE_LLAMA_CPP_TARGET, CMAKE_LLAMA_CPP_FFI_TARGET];
 
 macro_rules! probe {
     ($name: expr, $version: expr) => {
         if let Err(_) = pkg_config::probe_library($name) {
@@ -16,11 +18,12 @@ macro_rules! probe {
     };
 }
 
-fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf {
-    let install_path = env::var("CMAKE_INSTALL_PREFIX")
-        .map(|val| PathBuf::from(val))
-        .unwrap_or(out_dir.join("dist"));
+fn build_backend(
+    is_debug: bool,
+    opt_level: &str,
+    out_dir: &Path,
+    install_path: &PathBuf,
+) -> PathBuf {
     let build_cuda = option_env!("LLAMA_CPP_BUILD_CUDA").unwrap_or("OFF");
     let cuda_archs =
         option_env!("LLAMA_CPP_TARGET_CUDA_ARCHS").unwrap_or(CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS);
@@ -38,41 +41,28 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf
         .define("LLAMA_CPP_TARGET_CUDA_ARCHS", cuda_archs)
         .build();
 
-    // Additional transitive CMake dependencies
-    let deps_folder = out_dir.join("build").join("_deps");
-    for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
-        let dep_name = match is_debug {
-            true => format!("{}d", dependency),
-            false => String::from(dependency),
-        };
-        let dep_path = deps_folder.join(format!("{}-build", dependency));
-        println!("cargo:rustc-link-search={}", dep_path.display());
-        println!("cargo:rustc-link-lib=static={}", dep_name);
-    }
+    let lib_path = install_path.join("lib64");
+    println!("cargo:rustc-link-search=native={}", lib_path.display());
 
     let deps_folder = out_dir.join("build").join("_deps");
     deps_folder
 }
 
-fn build_ffi_layer(deps_folder: &PathBuf) {
-    println!("cargo:warning={}", &deps_folder.display());
+fn build_ffi_layer(deps_folder: &Path, install_prefix: &Path) {
+    println!("cargo:warning={}", deps_folder.display());
     CFG.include_prefix = "backends/llamacpp";
     cxx_build::bridge("src/lib.rs")
         .static_flag(true)
         .std("c++23")
-        .include(deps_folder.join("fmt-src").join("include"))
-        .include(deps_folder.join("spdlog-src").join("include"))
-        .include(deps_folder.join("llama-src").join("common"))
-        .include(deps_folder.join("llama-src").join("ggml").join("include"))
-        .include(deps_folder.join("llama-src").join("include"))
-        .include("csrc/backend.hpp")
-        .file("csrc/ffi.cpp")
-        .compile(CMAKE_LLAMA_CPP_TARGET);
-
-    println!("cargo:rerun-if-changed=CMakeLists.txt");
-    println!("cargo:rerun-if-changed=csrc/backend.hpp");
-    println!("cargo:rerun-if-changed=csrc/backend.cpp");
-    println!("cargo:rerun-if-changed=csrc/ffi.hpp");
+        .include(deps_folder.join("spdlog-src").join("include")) // Why spdlog doesnt install headers?
+        // .include(deps_folder.join("fmt-src").join("include")) // Why spdlog doesnt install headers?
+        // .include(deps_folder.join("llama-src").join("include")) // Why spdlog doesnt install headers?
+        .include(deps_folder.join("llama-src").join("ggml").join("include")) // Why spdlog doesnt install headers?
+        .include(deps_folder.join("llama-src").join("common").join("include")) // Why spdlog doesnt install headers?
+        .include(install_prefix.join("include"))
+        .include("csrc")
+        .file("csrc/ffi.hpp")
+        .compile(CMAKE_LLAMA_CPP_FFI_TARGET);
 }
 
 fn main() {
@@ -84,17 +74,35 @@ fn main() {
         _ => (false, "3"),
     };
 
+    let install_path = env::var("CMAKE_INSTALL_PREFIX")
+        .map(|val| PathBuf::from(val))
+        .unwrap_or(out_dir.join("dist"));
+
     // Build the backend
-    let deps_folder = build_backend(is_debug, opt_level, &out_dir);
+    let deps_path = build_backend(is_debug, opt_level, out_dir.as_path(), &install_path);
 
     // Build the FFI layer calling the backend above
-    build_ffi_layer(&deps_folder);
+    build_ffi_layer(&deps_path, &install_path);
 
     // Emit linkage search path
     probe!("ompi", MPI_REQUIRED_VERSION);
 
     // Backend
-    // BACKEND_DEPS.iter().for_each(|name| {
-    //     println!("cargo:rustc-link-lib=static={}", name);
-    // });
+    BACKEND_DEPS.iter().for_each(|name| {
+        println!("cargo:rustc-link-lib=static={}", name);
+    });
+
+    // Linkage info
+    println!("cargo:rustc-link-search=native={}", out_dir.display());
+    println!("cargo:rustc-link-lib=static=fmtd");
+    println!("cargo:rustc-link-lib=static=spdlogd");
+    println!("cargo:rustc-link-lib=static=common");
+    println!("cargo:rustc-link-lib=dylib=ggml");
+    println!("cargo:rustc-link-lib=dylib=llama");
+
+    // Rerun if one of these file change
+    println!("cargo:rerun-if-changed=CMakeLists.txt");
+    println!("cargo:rerun-if-changed=csrc/backend.hpp");
+    println!("cargo:rerun-if-changed=csrc/backend.cpp");
+    println!("cargo:rerun-if-changed=csrc/ffi.hpp");
 }
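
Note on the pattern above: this is the usual cargo build-script flow of configuring and installing a CMake project into a prefix, then emitting cargo:rustc-link-search / cargo:rustc-link-lib so rustc links against the installed archives. A minimal, illustrative sketch of that flow follows; the target name and CMake options are borrowed from this diff, everything else is assumed boilerplate and not the committed file.

// build.rs sketch (illustrative only)
use std::env;
use std::path::PathBuf;

fn main() {
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());

    // Assumed default: install under $OUT_DIR/dist unless CMAKE_INSTALL_PREFIX overrides it.
    let install_path = env::var("CMAKE_INSTALL_PREFIX")
        .map(PathBuf::from)
        .unwrap_or_else(|_| out_dir.join("dist"));

    // Configure, build and install the CMake project that lives next to this build script.
    let _ = cmake::Config::new(".")
        .define("CMAKE_INSTALL_PREFIX", &install_path)
        .define("LLAMA_CPP_BUILD_CUDA", "OFF")
        .build();

    // Point rustc at the installed static archives and name what to link.
    println!("cargo:rustc-link-search=native={}", install_path.join("lib64").display());
    println!("cargo:rustc-link-lib=static=tgi_llamacpp_backend_impl");
    println!("cargo:rerun-if-changed=CMakeLists.txt");
}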

backends/llamacpp/csrc/backend.cpp

@@ -14,33 +14,35 @@
 #include "backend.hpp"
 
-namespace huggingface::tgi::backends::llama {
-    std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
-    CreateLlamaCppBackend(const std::filesystem::path& modelPath) {
+namespace huggingface::tgi::backends::llamacpp {
+    [[nodiscard]]
+    std::expected<std::pair<llama_model *, llama_context *>, TgiLlamaCppBackendError>
+    TgiLlamaCppBackend::FromGGUF(const std::filesystem::path &modelPath) noexcept {
         SPDLOG_DEBUG(FMT_STRING("Loading model from {}"), modelPath);
 
         llama_backend_init();
         llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
 
         // Load the model
-        if(!exists(modelPath)) {
+        if (!exists(modelPath)) {
             return std::unexpected(TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST);
         }
 
         auto params = llama_model_default_params();
-        auto* model = llama_load_model_from_file(modelPath.c_str(), params);
-        auto* context = llama_new_context_with_model(model, {
+        auto *model = llama_load_model_from_file(modelPath.c_str(), params);
+        auto *context = llama_new_context_with_model(model, {
                 .n_batch = 1,
                 .n_threads = 16,
                 .attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL,
                 .flash_attn = false,
         });
 
-        return std::make_unique<huggingface::tgi::backends::llama::TgiLlamaCppBackend>(model, context);
+        return std::make_pair(model, context);
     }
 
-    huggingface::tgi::backends::llama::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model, llama_context *const ctx)
-        : model(model), ctx(ctx) {
+    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model,
+                                                                                 llama_context *const ctx)
+            : model(model), ctx(ctx) {
 #ifndef NDEBUG
         char modelName[256];
         llama_model_meta_val_str(llama_get_model(ctx), "general.name", modelName, sizeof(modelName));
@@ -48,13 +50,13 @@ namespace huggingface::tgi::backends::llama {
 #endif
     }
 
-    huggingface::tgi::backends::llama::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
+    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
         if (ctx) {
             SPDLOG_DEBUG("Freeing llama.cpp context");
             llama_free(ctx);
         }
 
-        if(model) {
+        if (model) {
             SPDLOG_DEBUG("Freeing llama.cpp model");
             llama_free_model(model);
         }
@@ -63,7 +65,8 @@ namespace huggingface::tgi::backends::llama {
     std::vector<TgiLlamaCppBackend::TokenId> TgiLlamaCppBackend::Tokenize(const std::string &text) const {
         std::vector<TgiLlamaCppBackend::TokenId> tokens(llama_n_seq_max(ctx));
 
-        if(auto nTokens = llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true, true); nTokens < 0){
+        if (auto nTokens = llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true,
+                                          true); nTokens < 0) {
             tokens.resize(-nTokens);
             llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true, true);
         } else {
@@ -75,14 +78,15 @@ namespace huggingface::tgi::backends::llama {
     }
 
     std::unique_ptr<llama_sampler *> TgiLlamaCppBackend::GetSamplerFromArgs(
-            const uint32_t topK, const float_t topP, const float_t frequencyPenalty, const float_t repetitionPenalty, const uint64_t seed) {
+            const uint32_t topK, const float_t topP, const float_t frequencyPenalty, const float_t repetitionPenalty,
+            const uint64_t seed) {
         auto *sampler = llama_sampler_chain_init({.no_perf = false});
 
         // Penalties
         llama_sampler_chain_add(sampler, llama_sampler_init_penalties(
                 llama_n_vocab(model),
                 llama_token_eos(model),
-                llama_token_nl (model),
+                llama_token_nl(model),
                 0.0f,
                 repetitionPenalty,
                 frequencyPenalty,
@@ -92,15 +96,16 @@ namespace huggingface::tgi::backends::llama {
         ));
         llama_sampler_chain_add(sampler, llama_sampler_init_top_k(static_cast<int32_t>(topK)));
 
-        if(0 < topP && topP < 1) {
+        if (0 < topP && topP < 1) {
             llama_sampler_chain_add(sampler, llama_sampler_init_top_p(topP, 1));
         }
 
         llama_sampler_chain_add(sampler, llama_sampler_init_dist(seed));
-        return std::make_unique<llama_sampler*>(sampler);
+        return std::make_unique<llama_sampler *>(sampler);
     }
 
-    std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
+    std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError>
+    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::Generate(
             std::span<const TokenId> tokens,
             const uint32_t topK,
             const float_t topP,
@@ -108,7 +113,7 @@ namespace huggingface::tgi::backends::llama {
             const float_t repetitionPenalty,
             const uint32_t maxNewTokens,
             const uint64_t seed
     ) {
         SPDLOG_DEBUG(FMT_STRING("Received {:d} tokens to schedule"), tokens.size());
 
         // Allocate generation result
@@ -120,7 +125,7 @@ namespace huggingface::tgi::backends::llama {
         auto sampler = GetSamplerFromArgs(topK, topP, frequencyPenalty, repetitionPenalty, seed);
 
         // Decode
-        for(auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
+        for (auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
 #ifndef NDEBUG
             const auto start = std::chrono::steady_clock::now();
             const auto status = llama_decode(ctx, batch);

backends/llamacpp/csrc/backend.hpp

@@ -9,12 +9,14 @@
 #include <filesystem>
 #include <memory>
 #include <span>
+#include <vector>
 
 #include <llama.h>
 
 #define LLAMA_SUCCESS(x) x == 0
 
-namespace huggingface::tgi::backends::llama {
-    enum TgiLlamaCppBackendError: uint8_t {
+namespace huggingface::tgi::backends::llamacpp {
+    enum TgiLlamaCppBackendError : uint8_t {
         MODEL_FILE_DOESNT_EXIST = 1
     };
@@ -22,8 +24,8 @@ namespace huggingface::tgi::backends::llama {
         using TokenId = llama_token;
 
     private:
-        llama_model* model;
-        llama_context* ctx;
+        llama_model *model;
+        llama_context *ctx;
 
         /**
          *
@@ -35,7 +37,15 @@ namespace huggingface::tgi::backends::llama {
                 uint32_t topK, float_t topP, float_t frequencyPenalty, float_t repetitionPenalty, uint64_t seed);
 
     public:
+        /**
+         *
+         * @return
+         */
+        static std::expected<std::pair<llama_model *, llama_context *>, TgiLlamaCppBackendError>
+        FromGGUF(const std::filesystem::path &) noexcept;
+
         TgiLlamaCppBackend(llama_model *model, llama_context *ctx);
         ~TgiLlamaCppBackend();
 
         /**
@@ -44,7 +54,7 @@ namespace huggingface::tgi::backends::llama {
          * @return
          */
         [[nodiscard("Tokens will be freed after this call if not assigned to an lvalue")]]
-        std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string& text) const;
+        std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string &text) const;
 
         /**
          *
@@ -71,7 +81,7 @@ namespace huggingface::tgi::backends::llama {
     [[nodiscard("Create backend will be freed after this call if not assigned to an lvalue")]]
     std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
-    CreateLlamaCppBackend(const std::filesystem::path& root);
+    CreateLlamaCppBackend(const std::filesystem::path &root);
 }
 
 #endif //TGI_LLAMA_CPP_BACKEND_BACKEND_HPP

backends/llamacpp/csrc/ffi.hpp

@@ -5,14 +5,44 @@
 #ifndef TGI_LLAMA_CPP_BACKEND_FFI_HPP
 #define TGI_LLAMA_CPP_BACKEND_FFI_HPP
 
+#include <exception>
+#include <filesystem>
+#include <string_view>
+
+#include <spdlog/spdlog.h>
+
 #include "backend.hpp"
-//#include "backends/llamacpp/src/lib.rs.h"
 
-namespace huggingface::tgi::backends::llama {
-    class LlamaCppBackendImpl {
+namespace huggingface::tgi::backends::llamacpp::impl {
+    class LlamaCppBackendImpl;
+}
+
+#include "backends/llamacpp/src/lib.rs.h"
+
+namespace huggingface::tgi::backends::llamacpp::impl {
+    class LlamaCppBackendException : std::exception {
 
     };
+
+    class LlamaCppBackendImpl {
+    private:
+        TgiLlamaCppBackend _inner;
+
+    public:
+        LlamaCppBackendImpl(llama_model *model, llama_context *context) : _inner(model, context) {}
+    };
+
+    std::unique_ptr<LlamaCppBackendImpl> CreateLlamaCppBackendImpl(rust::Str modelPath) {
+        const auto cxxPath = std::string_view(modelPath);
+        if (auto maybe = TgiLlamaCppBackend::FromGGUF(std::filesystem::path(cxxPath)); maybe.has_value()) {
+            auto [model, context] = *maybe;
+            return std::make_unique<LlamaCppBackendImpl>(model, context);
+        } else {
+            throw LlamaCppBackendException();
+        }
+    }
 }

backends/llamacpp/offline/main.cpp

@@ -10,6 +10,8 @@
 #include <spdlog/spdlog.h>
 #include "../csrc/backend.hpp"
 
+using namespace huggingface::tgi::backends::llamacpp;
+
 int main(int argc, char** argv) {
     if (argc < 2) {
         fmt::print("No model folder provider");
@@ -21,7 +23,7 @@ int main(int argc, char** argv) {
     const auto prompt = "My name is Morgan";
 
     const auto modelPath = absolute(std::filesystem::path(argv[1]));
-    if (auto maybeBackend = huggingface::tgi::backends::llama::CreateLlamaCppBackend(modelPath); maybeBackend.has_value()) {
+    if (auto maybeBackend = CreateLlamaCppBackend(modelPath); maybeBackend.has_value()) {
         // Retrieve the backend
         const auto& backend = *maybeBackend;
@@ -38,7 +40,7 @@ int main(int argc, char** argv) {
     } else {
         switch (maybeBackend.error()) {
-            case huggingface::tgi::backends::llama::TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST:
+            case TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST:
                 fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Specified file {} doesnt exist", modelPath);
                 return maybeBackend.error();
         }

backends/llamacpp/src/backend.rs

@@ -1,31 +1,66 @@
 use crate::ffi::{create_llamacpp_backend, LlamaCppBackendImpl};
+use async_trait::async_trait;
 use cxx::UniquePtr;
-use std::path::Path;
+use std::path::{Path, PathBuf};
+use std::sync::Arc;
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
+use thiserror::Error;
+use tokio::task::spawn_blocking;
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::info;
 
-pub struct TgiLlamaCppBakend {
-    backend: UniquePtr<LlamaCppBackendImpl>,
+unsafe impl Send for LlamaCppBackendImpl {}
+
+#[derive(Debug, Error)]
+pub enum LlamaCppBackendError {
+    #[error("Provided GGUF model path {0} doesn't exist")]
+    ModelFileDoesntExist(String),
+
+    #[error("Failed to initialize model from GGUF file {0}: {1}")]
+    ModelInitializationFailed(PathBuf, String),
 }
 
-impl TgiLlamaCppBakend {
-    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self, ()> {
-        Ok(Self {
-            backend: create_llamacpp_backend(model_path.as_ref().to_str().unwrap()),
-        })
+pub struct LlamaCppBackend {}
+
+impl LlamaCppBackend {
+    pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
+        let path = Arc::new(model_path.as_ref());
+        if !path.exists() {
+            return Err(LlamaCppBackendError::ModelFileDoesntExist(
+                path.display().to_string(),
+            ));
+        }
+
+        let mut backend = create_llamacpp_backend(path.to_str().unwrap()).map_err(|err| {
+            LlamaCppBackendError::ModelInitializationFailed(
+                path.to_path_buf(),
+                err.what().to_string(),
+            )
+        })?;
+
+        info!(
+            "Successfully initialized llama.cpp backend from {}",
+            path.display()
+        );
+        spawn_blocking(move || scheduler_loop(backend));
+        Ok(Self {})
     }
 }
 
-impl Backend for TgiLlamaCppBakend {
+async fn scheduler_loop(mut backend: UniquePtr<LlamaCppBackendImpl>) {}
+
+#[async_trait]
+impl Backend for LlamaCppBackend {
     fn schedule(
         &self,
-        request: ValidGenerateRequest,
+        _request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
         Err(InferError::GenerationError("Not implemented yet".into()))
     }
 
-    async fn health(&self, current_health: bool) -> bool {
-        todo!()
+    async fn health(&self, _: bool) -> bool {
+        true
     }
 }
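
A note on the `unsafe impl Send for LlamaCppBackendImpl {}` line above: a `UniquePtr` to an opaque C++ type is not `Send`, so without that marker the handle could not be moved into `spawn_blocking`. The sketch below shows the same pattern in isolation; `FfiHandle` and the empty `scheduler_loop` are stand-ins invented for the example (and it assumes a tokio dependency with the rt and macros features), not the types from this diff.

// Illustrative only: marking an FFI handle Send so it can move to a blocking worker thread.
use tokio::task::spawn_blocking;

// Stand-in for an opaque C++ object reached through a raw pointer / cxx::UniquePtr.
struct FfiHandle(*mut std::ffi::c_void);

// SAFETY (assumption): the wrapped object is only ever touched from one thread at a time,
// which is what handing it to a single scheduler loop is meant to guarantee.
unsafe impl Send for FfiHandle {}

fn scheduler_loop(_handle: FfiHandle) {
    // The real backend would pop requests off a channel here and drive generation.
}

#[tokio::main]
async fn main() {
    let handle = FfiHandle(std::ptr::null_mut());
    // Without the `unsafe impl Send` above, this move would not compile.
    spawn_blocking(move || scheduler_loop(handle)).await.unwrap();
}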

backends/llamacpp/src/lib.rs

@@ -1,6 +1,6 @@
 pub mod backend;
 
-#[cxx::bridge(namespace = "huggingface::tgi::backends::llama")]
+#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp::impl")]
 mod ffi {
     unsafe extern "C++" {
         include!("backends/llamacpp/csrc/ffi.hpp");
@@ -9,8 +9,8 @@ mod ffi {
         type LlamaCppBackendImpl;
 
         #[rust_name = "create_llamacpp_backend"]
-        fn CreateLlamaCppBackend(
-            engine_folder: &str,
-        ) -> UniquePtr<LlamaCppBackendImpl>;
+        fn CreateLlamaCppBackendImpl(
+            modelPath: &str,
+        ) -> Result<UniquePtr<LlamaCppBackendImpl>>;
     }
 }

backends/llamacpp/src/main.rs

@@ -1,7 +1,8 @@
 use clap::{Parser, Subcommand};
+use std::path::PathBuf;
+use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};
 use text_generation_router::{server, usage_stats};
 use thiserror::Error;
-use text_generation_router::server::ApiDoc;
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -38,6 +39,8 @@ struct Args {
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,
+    #[clap(long, env, help = "Path to GGUF model file(s) to load")]
+    gguf_path: PathBuf,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
     #[clap(long, env)]
@@ -98,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
         hostname,
         port,
         master_shard_uds_path,
+        gguf_path,
         tokenizer_name,
         tokenizer_config_path,
         revision,
@@ -116,13 +120,13 @@ async fn main() -> Result<(), RouterError> {
         usage_stats,
     } = args;
 
-    if let Some(Commands::PrintSchema) = command {
-        use utoipa::OpenApi;
-        let api_doc = ApiDoc::openapi();
-        let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
-        println!("{}", api_doc);
-        std::process::exit(0);
-    };
+    // if let Some(Commands::PrintSchema) = command {
+    //     use utoipa::OpenApi;
+    //     let api_doc = ApiDoc::openapi();
+    //     let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
+    //     println!("{}", api_doc);
+    //     std::process::exit(0);
+    // };
     text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
 
     // Validate args
@@ -158,7 +162,7 @@ async fn main() -> Result<(), RouterError> {
         }
     }
 
-    let backend = LlamaCppBackend::new();
+    let backend = LlamaCppBackend::new(gguf_path)?;
 
     // Run server
     server::run(
@@ -185,7 +189,7 @@ async fn main() -> Result<(), RouterError> {
         max_client_batch_size,
         usage_stats,
     )
     .await?;
     Ok(())
 }
@@ -194,9 +198,9 @@ enum RouterError {
     #[error("Argument validation error: {0}")]
     ArgumentValidation(String),
     #[error("Backend failed: {0}")]
-    Backend(#[from] V3Error),
+    Backend(#[from] LlamaCppBackendError),
     #[error("WebServer error: {0}")]
     WebServer(#[from] server::WebServerError),
     #[error("Tokio runtime failed to start: {0}")]
     Tokio(#[from] std::io::Error),
 }
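
The `RouterError` change at the end is what lets `LlamaCppBackend::new(gguf_path)?` compile: thiserror's `#[from]` generates the `From` impl that `?` uses to lift the backend error into `RouterError::Backend`. A self-contained sketch of that mechanism, where the error variants mirror this diff and the rest (including the path) is invented for illustration:

// Illustrative sketch of the error conversion used by `?` in main.rs.
use thiserror::Error;

#[derive(Debug, Error)]
pub enum LlamaCppBackendError {
    #[error("Provided GGUF model path {0} doesn't exist")]
    ModelFileDoesntExist(String),
}

#[derive(Debug, Error)]
pub enum RouterError {
    #[error("Backend failed: {0}")]
    Backend(#[from] LlamaCppBackendError),
}

// Stand-in for LlamaCppBackend::new: fails when the (hypothetical) path is missing.
fn new_backend(gguf_path: &str) -> Result<(), LlamaCppBackendError> {
    if !std::path::Path::new(gguf_path).exists() {
        return Err(LlamaCppBackendError::ModelFileDoesntExist(gguf_path.to_string()));
    }
    Ok(())
}

fn main() -> Result<(), RouterError> {
    // `?` uses the `#[from]` impl to turn LlamaCppBackendError into RouterError::Backend.
    new_backend("/tmp/model.gguf")?;
    Ok(())
}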