From e4d803c94ef8a48a172a57f783d2e5f9c9387edd Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Thu, 24 Oct 2024 16:42:50 +0200
Subject: [PATCH] feat(backend): build and link through build.rs

---
 Cargo.lock                         | 90 ++++++++++++++++++++++++++++--
 backends/llamacpp/CMakeLists.txt   | 18 +++---
 backends/llamacpp/Cargo.toml       |  6 ++
 backends/llamacpp/build.rs         | 86 +++++++++++++++-------------
 backends/llamacpp/csrc/backend.cpp | 51 +++++++++--------
 backends/llamacpp/csrc/backend.hpp | 22 ++++++--
 backends/llamacpp/csrc/ffi.hpp     | 36 +++++++++++-
 backends/llamacpp/offline/main.cpp |  6 +-
 backends/llamacpp/src/backend.rs   | 59 ++++++++++++++++----
 backends/llamacpp/src/lib.rs       |  8 +--
 backends/llamacpp/src/main.rs      | 28 ++++++----
 11 files changed, 296 insertions(+), 114 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 4075556b..479e94d7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2732,6 +2732,20 @@ dependencies = [
  "thiserror",
 ]
 
+[[package]]
+name = "opentelemetry"
+version = "0.26.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "570074cc999d1a58184080966e5bd3bf3a9a4af650c3b05047c2621e7405cd17"
+dependencies = [
+ "futures-core",
+ "futures-sink",
+ "js-sys",
+ "once_cell",
+ "pin-project-lite",
+ "thiserror",
+]
+
 [[package]]
 name = "opentelemetry-otlp"
 version = "0.13.0"
@@ -2849,6 +2863,24 @@ dependencies = [
  "thiserror",
 ]
 
+[[package]]
+name = "opentelemetry_sdk"
+version = "0.26.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d2c627d9f4c9cdc1f21a29ee4bfbd6028fcb8bcf2a857b43f3abdf72c9c862f3"
+dependencies = [
+ "async-trait",
+ "futures-channel",
+ "futures-executor",
+ "futures-util",
+ "glob",
+ "once_cell",
+ "opentelemetry 0.26.0",
+ "percent-encoding",
+ "rand",
+ "thiserror",
+]
+
 [[package]]
 name = "option-ext"
 version = "0.2.0"
@@ -4187,12 +4219,14 @@ dependencies = [
 name = "text-generation-backend-llamacpp"
 version = "2.4.1-dev0"
 dependencies = [
+ "async-trait",
  "clap 4.5.20",
  "cmake",
  "cxx",
  "cxx-build",
  "hf-hub",
  "image",
+ "log",
  "metrics",
  "metrics-exporter-prometheus",
  "pkg-config",
@@ -4202,6 +4236,10 @@ dependencies = [
  "tokenizers",
  "tokio",
  "tokio-stream",
+ "tracing",
+ "tracing-opentelemetry 0.27.0",
+ "tracing-subscriber",
+ "utoipa 5.1.2",
 ]
 
 [[package]]
@@ -4330,7 +4368,7 @@ dependencies = [
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
  "ureq",
- "utoipa",
+ "utoipa 4.2.3",
  "utoipa-swagger-ui",
  "uuid",
  "vergen",
@@ -4381,7 +4419,7 @@ dependencies = [
  "tracing",
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
- "utoipa",
+ "utoipa 4.2.3",
  "utoipa-swagger-ui",
 ]
 
@@ -4432,7 +4470,7 @@ dependencies = [
  "tracing",
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
- "utoipa",
+ "utoipa 4.2.3",
  "utoipa-swagger-ui",
 ]
 
@@ -4946,6 +4984,24 @@ dependencies = [
  "web-time 1.1.0",
 ]
 
+[[package]]
+name = "tracing-opentelemetry"
+version = "0.27.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc58af5d3f6c5811462cabb3289aec0093f7338e367e5a33d28c0433b3c7360b"
+dependencies = [
+ "js-sys",
+ "once_cell",
+ "opentelemetry 0.26.0",
+ "opentelemetry_sdk 0.26.0",
+ "smallvec",
+ "tracing",
+ "tracing-core",
+ "tracing-log 0.2.0",
+ "tracing-subscriber",
+ "web-time 1.1.0",
+]
+
 [[package]]
 name = "tracing-opentelemetry-instrumentation-sdk"
 version = "0.16.0"
@@ -5136,7 +5192,19 @@ dependencies = [
  "indexmap 2.6.0",
  "serde",
  "serde_json",
- "utoipa-gen",
+ "utoipa-gen 4.3.0",
+]
+
+[[package]]
+name = "utoipa"
+version = "5.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5e12e84f0ff45b6818029cd0f67280e453c80132c1b9897df407ecc20b9f7cfd"
+dependencies = [
+ "indexmap 2.5.0",
+ "serde",
+ "serde_json",
+ "utoipa-gen 5.1.2",
 ]
 
 [[package]]
@@ -5152,6 +5220,18 @@ dependencies = [
  "proc-macro2",
  "quote",
  "regex",
  "syn 2.0.85",
 ]
 
+[[package]]
+name = "utoipa-gen"
+version = "5.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0dfc694d3a3118d2b9e80d68be83bf1aab7988510916934db83da61c14e7e6b2"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "regex",
+ "syn 2.0.79",
+]
+
 [[package]]
 name = "utoipa-swagger-ui"
 version = "6.0.0"
@@ -5164,7 +5244,7 @@ dependencies = [
  "rust-embed",
  "serde",
  "serde_json",
- "utoipa",
+ "utoipa 4.2.3",
  "zip",
 ]
 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e12e84f0ff45b6818029cd0f67280e453c80132c1b9897df407ecc20b9f7cfd" +dependencies = [ + "indexmap 2.5.0", + "serde", + "serde_json", + "utoipa-gen 5.1.2", ] [[package]] @@ -5152,6 +5220,18 @@ dependencies = [ "syn 2.0.85", ] +[[package]] +name = "utoipa-gen" +version = "5.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dfc694d3a3118d2b9e80d68be83bf1aab7988510916934db83da61c14e7e6b2" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "syn 2.0.79", +] + [[package]] name = "utoipa-swagger-ui" version = "6.0.0" @@ -5164,7 +5244,7 @@ dependencies = [ "rust-embed", "serde", "serde_json", - "utoipa", + "utoipa 4.2.3", "zip", ] diff --git a/backends/llamacpp/CMakeLists.txt b/backends/llamacpp/CMakeLists.txt index 644db5ae..c4b6f0ce 100644 --- a/backends/llamacpp/CMakeLists.txt +++ b/backends/llamacpp/CMakeLists.txt @@ -11,12 +11,12 @@ set(LLAMA_CPP_TARGET_CUDA_ARCHS "75-real;80-real;86-real;89-real;90-real" CACHE option(LLAMA_CPP_BUILD_OFFLINE_RUNNER "Flag to build the standalone c++ backend runner") option(LLAMA_CPP_BUILD_CUDA "Flag to build CUDA enabled inference through llama.cpp") -if(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" AND ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") +if (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" AND ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") message(STATUS "Targeting libc++") set(CMAKE_CXX_FLAGS -stdlib=libc++ ${CMAKE_CXX_FLAGS}) -else() +else () message(STATUS "Not using libc++ ${CMAKE_CXX_COMPILER_ID} ${CMAKE_SYSTEM_NAME}") -endif() +endif () # Add dependencies include(cmake/fmt.cmake) @@ -42,15 +42,17 @@ fetchcontent_declare( fetchcontent_makeavailable(llama) -add_library(tgi_llama_cpp_backend_impl STATIC csrc/backend.hpp csrc/backend.cpp) -target_compile_features(tgi_llama_cpp_backend_impl PRIVATE cxx_std_11) -target_link_libraries(tgi_llama_cpp_backend_impl PUBLIC fmt::fmt spdlog::spdlog llama common) +add_library(tgi_llamacpp_backend_impl STATIC csrc/backend.hpp csrc/backend.cpp) +target_compile_features(tgi_llamacpp_backend_impl PRIVATE cxx_std_11) +target_link_libraries(tgi_llamacpp_backend_impl PUBLIC fmt::fmt spdlog::spdlog llama common) + +install(TARGETS tgi_llamacpp_backend_impl spdlog llama common) if (${LLAMA_CPP_BUILD_OFFLINE_RUNNER}) message(STATUS "Building llama.cpp offline runner") - add_executable(tgi_llama_cpp_offline_runner offline/main.cpp) + add_executable(tgi_llama_cppoffline_runner offline/main.cpp) - target_link_libraries(tgi_llama_cpp_offline_runner PUBLIC tgi_llama_cpp_backend_impl llama common) + target_link_libraries(tgi_llamacpp_offline_runner PUBLIC tgi_llama_cpp_backend_impl llama common) endif () diff --git a/backends/llamacpp/Cargo.toml b/backends/llamacpp/Cargo.toml index fdd980c3..4a14dcdf 100644 --- a/backends/llamacpp/Cargo.toml +++ b/backends/llamacpp/Cargo.toml @@ -6,6 +6,7 @@ authors.workspace = true homepage.workspace = true [dependencies] +async-trait = "0.1" clap = { version = "4.5.19", features = ["derive"] } cxx = "1.0" hf-hub = { workspace = true } @@ -18,6 +19,11 @@ thiserror = "1.0.64" tokio = "1.40.0" tokio-stream = "0.1.16" tokenizers = { workspace = true } +tracing = "0.1" +tracing-opentelemetry = "0.27.0" +tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } +utoipa = { version = "5.1.2", features = ["axum_extras"] } +log = "0.4.22" [build-dependencies] cmake = "0.1" diff --git a/backends/llamacpp/build.rs b/backends/llamacpp/build.rs index d84e517f..642a9665 100644 --- 
diff --git a/backends/llamacpp/build.rs b/backends/llamacpp/build.rs
index d84e517f..642a9665 100644
--- a/backends/llamacpp/build.rs
+++ b/backends/llamacpp/build.rs
@@ -1,12 +1,14 @@
 use cxx_build::CFG;
 use std::env;
-use std::path::PathBuf;
+use std::path::{Path, PathBuf};
 
 const CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS: &str = "75-real;80-real;86-real;89-real;90-real";
-const CMAKE_LLAMA_CPP_TARGET: &str = "tgi_llama_cpp_backend_impl";
-const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
+const CMAKE_LLAMA_CPP_TARGET: &str = "tgi_llamacpp_backend_impl";
+const CMAKE_LLAMA_CPP_FFI_TARGET: &str = "tgi_llamacpp_backend";
 const MPI_REQUIRED_VERSION: &str = "4.1";
 
+const BACKEND_DEPS: [&str; 2] = [CMAKE_LLAMA_CPP_TARGET, CMAKE_LLAMA_CPP_FFI_TARGET];
+
 macro_rules! probe {
     ($name: expr, $version: expr) => {
         if let Err(_) = pkg_config::probe_library($name) {
@@ -16,11 +18,12 @@ macro_rules! probe {
     };
 }
 
-fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf {
-    let install_path = env::var("CMAKE_INSTALL_PREFIX")
-        .map(|val| PathBuf::from(val))
-        .unwrap_or(out_dir.join("dist"));
-
+fn build_backend(
+    is_debug: bool,
+    opt_level: &str,
+    out_dir: &Path,
+    install_path: &PathBuf,
+) -> PathBuf {
     let build_cuda = option_env!("LLAMA_CPP_BUILD_CUDA").unwrap_or("OFF");
     let cuda_archs =
         option_env!("LLAMA_CPP_TARGET_CUDA_ARCHS").unwrap_or(CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS);
@@ -38,41 +41,28 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf
         .define("LLAMA_CPP_TARGET_CUDA_ARCHS", cuda_archs)
         .build();
 
-    // Additional transitive CMake dependencies
-    let deps_folder = out_dir.join("build").join("_deps");
-    for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
-        let dep_name = match is_debug {
-            true => format!("{}d", dependency),
-            false => String::from(dependency),
-        };
-        let dep_path = deps_folder.join(format!("{}-build", dependency));
-        println!("cargo:rustc-link-search={}", dep_path.display());
-        println!("cargo:rustc-link-lib=static={}", dep_name);
-    }
+    let lib_path = install_path.join("lib64");
+    println!("cargo:rustc-link-search=native={}", lib_path.display());
 
     let deps_folder = out_dir.join("build").join("_deps");
     deps_folder
 }
 
-fn build_ffi_layer(deps_folder: &PathBuf) {
-    println!("cargo:warning={}", &deps_folder.display());
+fn build_ffi_layer(deps_folder: &Path, install_prefix: &Path) {
+    println!("cargo:warning={}", deps_folder.display());
     CFG.include_prefix = "backends/llamacpp";
     cxx_build::bridge("src/lib.rs")
         .static_flag(true)
         .std("c++23")
-        .include(deps_folder.join("fmt-src").join("include"))
-        .include(deps_folder.join("spdlog-src").join("include"))
-        .include(deps_folder.join("llama-src").join("common"))
-        .include(deps_folder.join("llama-src").join("ggml").join("include"))
-        .include(deps_folder.join("llama-src").join("include"))
-        .include("csrc/backend.hpp")
-        .file("csrc/ffi.cpp")
-        .compile(CMAKE_LLAMA_CPP_TARGET);
-
-    println!("cargo:rerun-if-changed=CMakeLists.txt");
-    println!("cargo:rerun-if-changed=csrc/backend.hpp");
-    println!("cargo:rerun-if-changed=csrc/backend.cpp");
-    println!("cargo:rerun-if-changed=csrc/ffi.hpp");
+        .include(deps_folder.join("spdlog-src").join("include")) // Why doesn't spdlog install its headers?
+        // .include(deps_folder.join("fmt-src").join("include"))
+        // .include(deps_folder.join("llama-src").join("include"))
+        .include(deps_folder.join("llama-src").join("ggml").join("include"))
+        .include(deps_folder.join("llama-src").join("common").join("include"))
+        .include(install_prefix.join("include"))
+        .include("csrc")
+        .file("csrc/ffi.hpp")
+        .compile(CMAKE_LLAMA_CPP_FFI_TARGET);
 }
 
 fn main() {
@@ -84,17 +74,35 @@ fn main() {
         _ => (false, "3"),
     };
 
+    let install_path = env::var("CMAKE_INSTALL_PREFIX")
+        .map(|val| PathBuf::from(val))
+        .unwrap_or(out_dir.join("dist"));
+
     // Build the backend
-    let deps_folder = build_backend(is_debug, opt_level, &out_dir);
+    let deps_path = build_backend(is_debug, opt_level, out_dir.as_path(), &install_path);
 
     // Build the FFI layer calling the backend above
-    build_ffi_layer(&deps_folder);
+    build_ffi_layer(&deps_path, &install_path);
 
     // Emit linkage search path
     probe!("ompi", MPI_REQUIRED_VERSION);
 
     // Backend
-    // BACKEND_DEPS.iter().for_each(|name| {
-    //     println!("cargo:rustc-link-lib=static={}", name);
-    // });
+    BACKEND_DEPS.iter().for_each(|name| {
+        println!("cargo:rustc-link-lib=static={}", name);
+    });
+
+    // Linkage info
+    println!("cargo:rustc-link-search=native={}", out_dir.display());
+    println!("cargo:rustc-link-lib=static=fmtd");
+    println!("cargo:rustc-link-lib=static=spdlogd");
+    println!("cargo:rustc-link-lib=static=common");
+    println!("cargo:rustc-link-lib=dylib=ggml");
+    println!("cargo:rustc-link-lib=dylib=llama");
+
+    // Rerun if one of these files changes
+    println!("cargo:rerun-if-changed=CMakeLists.txt");
+    println!("cargo:rerun-if-changed=csrc/backend.hpp");
+    println!("cargo:rerun-if-changed=csrc/backend.cpp");
+    println!("cargo:rerun-if-changed=csrc/ffi.hpp");
 }
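Condensed to its essentials, the flow this build.rs now implements is: run CMake, compile the cxx bridge against the installed headers, then emit the link-search paths and libraries for rustc. The simplified sketch below follows the patch (the `lib64` directory, the `LLAMA_CPP_BUILD_CUDA` define, and compiling `csrc/ffi.hpp` directly are taken from it) but is illustrative rather than a drop-in replacement:

```rust
// build.rs sketch: CMake build, cxx bridge, then Cargo link directives.
use std::path::PathBuf;

fn main() {
    // 1. Build and install the C++ backend; `build()` returns the install prefix.
    let install_path: PathBuf = cmake::Config::new(".")
        .define("LLAMA_CPP_BUILD_CUDA", "OFF")
        .build();

    // 2. Compile the cxx FFI layer against the installed headers.
    cxx_build::bridge("src/lib.rs")
        .std("c++23")
        .include(install_path.join("include"))
        .include("csrc")
        .file("csrc/ffi.hpp")
        .compile("tgi_llamacpp_backend");

    // 3. Tell rustc where the archives live and what to link.
    println!(
        "cargo:rustc-link-search=native={}",
        install_path.join("lib64").display()
    );
    println!("cargo:rustc-link-lib=static=tgi_llamacpp_backend_impl");
    println!("cargo:rustc-link-lib=dylib=llama");

    println!("cargo:rerun-if-changed=CMakeLists.txt");
    println!("cargo:rerun-if-changed=csrc/ffi.hpp");
}
```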
diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp
index c8806957..ba4a02d5 100644
--- a/backends/llamacpp/csrc/backend.cpp
+++ b/backends/llamacpp/csrc/backend.cpp
@@ -14,33 +14,35 @@
 #include "backend.hpp"
 
-namespace huggingface::tgi::backends::llama {
-    std::expected, TgiLlamaCppBackendError>
-    CreateLlamaCppBackend(const std::filesystem::path& modelPath) {
+namespace huggingface::tgi::backends::llamacpp {
+    [[nodiscard]]
+    std::expected<std::pair<llama_model *, llama_context *>, TgiLlamaCppBackendError>
+    TgiLlamaCppBackend::FromGGUF(const std::filesystem::path &modelPath) noexcept {
         SPDLOG_DEBUG(FMT_STRING("Loading model from {}"), modelPath);
 
         llama_backend_init();
         llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
 
         // Load the model
-        if(!exists(modelPath)) {
+        if (!exists(modelPath)) {
             return std::unexpected(TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST);
         }
 
         auto params = llama_model_default_params();
-        auto* model = llama_load_model_from_file(modelPath.c_str(), params);
-        auto* context = llama_new_context_with_model(model, {
-            .n_batch = 1,
-            .n_threads = 16,
-            .attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL,
-            .flash_attn = false,
+        auto *model = llama_load_model_from_file(modelPath.c_str(), params);
+        auto *context = llama_new_context_with_model(model, {
+                .n_batch = 1,
+                .n_threads = 16,
+                .attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL,
+                .flash_attn = false,
         });
 
-        return std::make_unique(model, context);
+        return std::make_pair(model, context);
     }
 
-    huggingface::tgi::backends::llama::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model, llama_context *const ctx)
-    : model(model), ctx(ctx) {
+    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model,
+                                                                                 llama_context *const ctx)
+            : model(model), ctx(ctx) {
 #ifndef NDEBUG
         char modelName[256];
         llama_model_meta_val_str(llama_get_model(ctx), "general.name", modelName, sizeof(modelName));
@@ -48,13 +50,13 @@ namespace huggingface::tgi::backends::llama {
 #endif
     }
 
-    huggingface::tgi::backends::llama::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
+    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
         if (ctx) {
             SPDLOG_DEBUG("Freeing llama.cpp context");
             llama_free(ctx);
         }
 
-        if(model) {
+        if (model) {
             SPDLOG_DEBUG("Freeing llama.cpp model");
             llama_free_model(model);
         }
@@ -63,7 +65,8 @@ namespace huggingface::tgi::backends::llama {
     std::vector<TgiLlamaCppBackend::TokenId> TgiLlamaCppBackend::Tokenize(const std::string &text) const {
         std::vector<TgiLlamaCppBackend::TokenId> tokens(llama_n_seq_max(ctx));
 
-        if(auto nTokens = llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true, true); nTokens < 0){
+        if (auto nTokens = llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true,
+                                          true); nTokens < 0) {
             tokens.resize(-nTokens);
             llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true, true);
         } else {
@@ -75,14 +78,15 @@ namespace huggingface::tgi::backends::llama {
     }
 
     std::unique_ptr TgiLlamaCppBackend::GetSamplerFromArgs(
-            const uint32_t topK, const float_t topP, const float_t frequencyPenalty, const float_t repetitionPenalty, const uint64_t seed) {
+            const uint32_t topK, const float_t topP, const float_t frequencyPenalty, const float_t repetitionPenalty,
+            const uint64_t seed) {
         auto *sampler = llama_sampler_chain_init({.no_perf = false});
 
         // Penalties
         llama_sampler_chain_add(sampler, llama_sampler_init_penalties(
                 llama_n_vocab(model),
                 llama_token_eos(model),
-                llama_token_nl (model),
+                llama_token_nl(model),
                 0.0f,
                 repetitionPenalty,
                 frequencyPenalty,
@@ -92,15 +96,16 @@ namespace huggingface::tgi::backends::llama {
         ));
 
         llama_sampler_chain_add(sampler, llama_sampler_init_top_k(static_cast<int32_t>(topK)));
-        if(0 < topP && topP < 1) {
+        if (0 < topP && topP < 1) {
             llama_sampler_chain_add(sampler, llama_sampler_init_top_p(topP, 1));
         }
 
         llama_sampler_chain_add(sampler, llama_sampler_init_dist(seed));
-        return std::make_unique(sampler);
+        return std::make_unique(sampler);
     }
 
-    std::expected, TgiLlamaCppBackendError> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
+    std::expected, TgiLlamaCppBackendError>
+    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::Generate(
             std::span tokens,
             const uint32_t topK,
             const float_t topP,
@@ -108,7 +113,7 @@ namespace huggingface::tgi::backends::llama {
             const float_t frequencyPenalty,
             const float_t repetitionPenalty,
             const uint32_t maxNewTokens,
             const uint64_t seed
-    ) {
+            ) {
         SPDLOG_DEBUG(FMT_STRING("Received {:d} tokens to schedule"), tokens.size());
 
         // Allocate generation result
@@ -120,7 +125,7 @@ namespace huggingface::tgi::backends::llama {
         auto sampler = GetSamplerFromArgs(topK, topP, frequencyPenalty, repetitionPenalty, seed);
 
         // Decode
-        for(auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
+        for (auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
 #ifndef NDEBUG
             const auto start = std::chrono::steady_clock::now();
             const auto status = llama_decode(ctx, batch);
diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp
index 7075642a..7fa47e84 100644
--- a/backends/llamacpp/csrc/backend.hpp
+++ b/backends/llamacpp/csrc/backend.hpp
@@ -9,12 +9,14 @@
 #include
 #include
 #include
+#include
+
 #include
 
 #define LLAMA_SUCCESS(x) x == 0
 
-namespace huggingface::tgi::backends::llama {
-    enum TgiLlamaCppBackendError: uint8_t {
+namespace huggingface::tgi::backends::llamacpp {
+    enum TgiLlamaCppBackendError : uint8_t {
         MODEL_FILE_DOESNT_EXIST = 1
     };
 
@@ -22,8 +24,8 @@ namespace huggingface::tgi::backends::llama {
         using TokenId = llama_token;
 
     private:
-        llama_model* model;
-        llama_context* ctx;
+        llama_model *model;
+        llama_context *ctx;
 
         /**
          *
@@ -35,7 +37,15 @@ namespace huggingface::tgi::backends::llama {
                 uint32_t topK, float_t topP, float_t frequencyPenalty, float_t repetitionPenalty, uint64_t seed);
 
     public:
+        /**
+         *
+         * @return
+         */
+        static std::expected<std::pair<llama_model *, llama_context *>, TgiLlamaCppBackendError>
+        FromGGUF(const std::filesystem::path &) noexcept;
+
         TgiLlamaCppBackend(llama_model *model, llama_context *ctx);
+
         ~TgiLlamaCppBackend();
 
         /**
@@ -44,7 +54,7 @@ namespace huggingface::tgi::backends::llama {
          * @return
          */
         [[nodiscard("Tokens will be freed after this call if not assigned to an lvalue")]]
-        std::vector<TokenId> Tokenize(const std::string& text) const;
+        std::vector<TokenId> Tokenize(const std::string &text) const;
 
         /**
          *
@@ -71,7 +81,7 @@ namespace huggingface::tgi::backends::llama {
 
     [[nodiscard("Create backend will be freed after this call if not assigned to an lvalue")]]
     std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
-    CreateLlamaCppBackend(const std::filesystem::path& root);
+    CreateLlamaCppBackend(const std::filesystem::path &root);
 }
 
 #endif //TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
diff --git a/backends/llamacpp/csrc/ffi.hpp b/backends/llamacpp/csrc/ffi.hpp
index e924316e..82f3f296 100644
--- a/backends/llamacpp/csrc/ffi.hpp
+++ b/backends/llamacpp/csrc/ffi.hpp
@@ -5,14 +5,44 @@
 #ifndef TGI_LLAMA_CPP_BACKEND_FFI_HPP
 #define TGI_LLAMA_CPP_BACKEND_FFI_HPP
 
+#include
+#include
+#include
+
+#include
 #include "backend.hpp"
-//#include "backends/llamacpp/src/lib.rs.h"
+
+namespace huggingface::tgi::backends::llamacpp::impl {
+    class LlamaCppBackendImpl;
+}
 
-namespace huggingface::tgi::backends::llama {
-    class LlamaCppBackendImpl {
+#include "backends/llamacpp/src/lib.rs.h"
+
+
+namespace huggingface::tgi::backends::llamacpp::impl {
+
+    class LlamaCppBackendException : std::exception {
     };
+
+    class LlamaCppBackendImpl {
+    private:
+        TgiLlamaCppBackend _inner;
+
+    public:
+        LlamaCppBackendImpl(llama_model *model, llama_context *context) : _inner(model, context) {}
+    };
+
+    std::unique_ptr<LlamaCppBackendImpl> CreateLlamaCppBackendImpl(rust::Str modelPath) {
+        const auto cxxPath = std::string_view(modelPath);
+        if (auto maybe = TgiLlamaCppBackend::FromGGUF(std::filesystem::path(cxxPath)); maybe.has_value()) {
+            auto [model, context] = *maybe;
+            return std::make_unique<LlamaCppBackendImpl>(model, context);
+        } else {
+            throw LlamaCppBackendException();
+        }
+    }
 }
diff --git a/backends/llamacpp/offline/main.cpp b/backends/llamacpp/offline/main.cpp
index c2ae05c7..56eb88c5 100644
--- a/backends/llamacpp/offline/main.cpp
+++ b/backends/llamacpp/offline/main.cpp
@@ -10,6 +10,8 @@
 #include
 #include "../csrc/backend.hpp"
 
+using namespace huggingface::tgi::backends::llamacpp;
+
 int main(int argc, char** argv) {
     if (argc < 2) {
         fmt::print("No model folder provided");
@@ -21,7 +23,7 @@ int main(int argc, char** argv) {
     const auto prompt = "My name is Morgan";
 
     const auto modelPath = absolute(std::filesystem::path(argv[1]));
-    if (auto maybeBackend = huggingface::tgi::backends::llama::CreateLlamaCppBackend(modelPath); maybeBackend.has_value()) {
+    if (auto maybeBackend = CreateLlamaCppBackend(modelPath); maybeBackend.has_value()) {
         // Retrieve the backend
         const auto& backend = *maybeBackend;
 
@@ -38,7 +40,7 @@ int main(int argc, char** argv) {
         }
     } else {
         switch (maybeBackend.error()) {
-            case huggingface::tgi::backends::llama::TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST:
+            case TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST:
                 fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Specified file {} doesn't exist", modelPath);
                 return maybeBackend.error();
         }
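Because `CreateLlamaCppBackendImpl` reports failure by throwing, the matching declaration in `src/lib.rs` (next section) has to be marked fallible; cxx then catches the C++ exception and hands it to Rust as a `cxx::Exception`. A minimal, generic illustration of that mechanism, using placeholder names (`Thing`, `make_thing`, the included header) rather than the backend's own:

```rust
// Generic cxx sketch: a throwing C++ factory declared as returning Result
// comes back to Rust as Result<_, cxx::Exception>.
#[cxx::bridge]
mod ffi {
    unsafe extern "C++" {
        include!("demo/include/thing.hpp"); // hypothetical header

        type Thing;
        fn make_thing(path: &str) -> Result<UniquePtr<Thing>>;
    }
}

fn load(path: &str) -> Result<cxx::UniquePtr<ffi::Thing>, String> {
    // `what()` exposes the message carried by the C++ exception, if any.
    ffi::make_thing(path).map_err(|e| e.what().to_string())
}
```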
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 89daeee3..7b22e4a2 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -1,31 +1,66 @@
 use crate::ffi::{create_llamacpp_backend, LlamaCppBackendImpl};
+use async_trait::async_trait;
 use cxx::UniquePtr;
-use std::path::Path;
+use std::path::{Path, PathBuf};
+use std::sync::Arc;
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
+use thiserror::Error;
+use tokio::task::spawn_blocking;
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::info;
 
-pub struct TgiLlamaCppBakend {
-    backend: UniquePtr,
+unsafe impl Send for LlamaCppBackendImpl {}
+
+#[derive(Debug, Error)]
+pub enum LlamaCppBackendError {
+    #[error("Provided GGUF model path {0} doesn't exist")]
+    ModelFileDoesntExist(String),
+
+    #[error("Failed to initialize model from GGUF file {0}: {1}")]
+    ModelInitializationFailed(PathBuf, String),
 }
 
-impl TgiLlamaCppBakend {
-    pub fn new>(model_path: P) -> Result {
-        Ok(Self {
-            backend: create_llamacpp_backend(model_path.as_ref().to_str().unwrap()),
-        })
+pub struct LlamaCppBackend {}
+
+impl LlamaCppBackend {
+    pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
+        let path = Arc::new(model_path.as_ref());
+        if !path.exists() {
+            return Err(LlamaCppBackendError::ModelFileDoesntExist(
+                path.display().to_string(),
+            ));
+        }
+
+        let mut backend = create_llamacpp_backend(path.to_str().unwrap()).map_err(|err| {
+            LlamaCppBackendError::ModelInitializationFailed(
+                path.to_path_buf(),
+                err.what().to_string(),
+            )
+        })?;
+
+        info!(
+            "Successfully initialized llama.cpp backend from {}",
+            path.display()
+        );
+
+        spawn_blocking(move || scheduler_loop(backend));
+        Ok(Self {})
     }
 }
 
-impl Backend for TgiLlamaCppBakend {
+async fn scheduler_loop(mut backend: UniquePtr<LlamaCppBackendImpl>) {}
+
+#[async_trait]
+impl Backend for LlamaCppBackend {
     fn schedule(
         &self,
-        request: ValidGenerateRequest,
+        _request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
         Err(InferError::GenerationError("Not implemented yet".into()))
     }
 
-    async fn health(&self, current_health: bool) -> bool {
-        todo!()
+    async fn health(&self, _: bool) -> bool {
+        true
     }
 }
diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs
index d25e3ca0..2bfc3065 100644
--- a/backends/llamacpp/src/lib.rs
+++ b/backends/llamacpp/src/lib.rs
@@ -1,6 +1,6 @@
 pub mod backend;
 
-#[cxx::bridge(namespace = "huggingface::tgi::backends::llama")]
+#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp::impl")]
 mod ffi {
     unsafe extern "C++" {
         include!("backends/llamacpp/csrc/ffi.hpp");
@@ -9,8 +9,8 @@ mod ffi {
         type LlamaCppBackendImpl;
 
         #[rust_name = "create_llamacpp_backend"]
-        fn CreateLlamaCppBackend(
-            engine_folder: &str,
-        ) -> UniquePtr;
+        fn CreateLlamaCppBackendImpl(
+            modelPath: &str,
+        ) -> Result<UniquePtr<LlamaCppBackendImpl>>;
     }
 }
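From the caller's side, the new constructor is the single fallible entry point; model loading and the (stub) scheduler task live behind it. A small usage sketch — the model path is a placeholder, and it must run inside a Tokio runtime because `new` calls `spawn_blocking` (assumes tokio's `rt-multi-thread` and `macros` features):

```rust
use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};

#[tokio::main]
async fn main() -> Result<(), LlamaCppBackendError> {
    // Placeholder path; in the router binary it comes from the new `--gguf-path` flag.
    let _backend = LlamaCppBackend::new("/models/model-q4_k_m.gguf")?;
    Ok(())
}
```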
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 7226473c..7420e16a 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -1,7 +1,8 @@
 use clap::{Parser, Subcommand};
+use std::path::PathBuf;
+use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};
 use text_generation_router::{server, usage_stats};
 use thiserror::Error;
-use text_generation_router::server::ApiDoc;
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -38,6 +39,8 @@ struct Args {
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,
+    #[clap(long, env, help = "Path to GGUF model file(s) to load")]
+    gguf_path: PathBuf,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
     #[clap(long, env)]
@@ -98,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
         hostname,
         port,
         master_shard_uds_path,
+        gguf_path,
         tokenizer_name,
         tokenizer_config_path,
         revision,
@@ -116,13 +120,13 @@ async fn main() -> Result<(), RouterError> {
         usage_stats,
     } = args;
 
-    if let Some(Commands::PrintSchema) = command {
-        use utoipa::OpenApi;
-        let api_doc = ApiDoc::openapi();
-        let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
-        println!("{}", api_doc);
-        std::process::exit(0);
-    };
+    // if let Some(Commands::PrintSchema) = command {
+    //     use utoipa::OpenApi;
+    //     let api_doc = ApiDoc::openapi();
+    //     let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
+    //     println!("{}", api_doc);
+    //     std::process::exit(0);
+    // };
     text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
 
     // Validate args
@@ -158,7 +162,7 @@ async fn main() -> Result<(), RouterError> {
         }
     }
 
-    let backend = LlamaCppBackend::new();
+    let backend = LlamaCppBackend::new(gguf_path)?;
 
     // Run server
     server::run(
@@ -185,7 +189,7 @@ async fn main() -> Result<(), RouterError> {
         max_client_batch_size,
         usage_stats,
     )
-    .await?;
+        .await?;
     Ok(())
 }
 
@@ -194,9 +198,9 @@ enum RouterError {
     #[error("Argument validation error: {0}")]
     ArgumentValidation(String),
     #[error("Backend failed: {0}")]
-    Backend(#[from] V3Error),
+    Backend(#[from] LlamaCppBackendError),
     #[error("WebServer error: {0}")]
     WebServer(#[from] server::WebServerError),
     #[error("Tokio runtime failed to start: {0}")]
     Tokio(#[from] std::io::Error),
-}
\ No newline at end of file
+}
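The switch of the `Backend` variant from `V3Error` to `#[from] LlamaCppBackendError` is what lets the `LlamaCppBackend::new(gguf_path)?` call above bubble its error straight into `RouterError`. A trimmed-down, standalone illustration of that conversion (only the relevant variants are kept; names mirror the patch but the snippet is not part of it):

```rust
use thiserror::Error;

#[derive(Debug, Error)]
enum LlamaCppBackendError {
    #[error("Provided GGUF model path {0} doesn't exist")]
    ModelFileDoesntExist(String),
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Backend failed: {0}")]
    Backend(#[from] LlamaCppBackendError),
}

fn load_backend(path: &str) -> Result<(), LlamaCppBackendError> {
    // Stand-in for LlamaCppBackend::new: always fails for demonstration.
    Err(LlamaCppBackendError::ModelFileDoesntExist(path.to_string()))
}

fn start(path: &str) -> Result<(), RouterError> {
    // `?` applies the generated `From<LlamaCppBackendError> for RouterError` impl.
    load_backend(path)?;
    Ok(())
}

fn main() {
    println!("{:?}", start("/tmp/does-not-exist.gguf"));
}
```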