Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-24 08:22:07 +00:00
feat(backend): build and link through build.rs
This commit is contained in:
parent 355d8a55b4
commit e4d803c94e
Changed files summary: Cargo.lock (generated), 90 changed lines.
Cargo.lock:

@@ -2732,6 +2732,20 @@ dependencies = [
  "thiserror",
 ]
+
+[[package]]
+name = "opentelemetry"
+version = "0.26.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "570074cc999d1a58184080966e5bd3bf3a9a4af650c3b05047c2621e7405cd17"
+dependencies = [
+ "futures-core",
+ "futures-sink",
+ "js-sys",
+ "once_cell",
+ "pin-project-lite",
+ "thiserror",
+]
 
 [[package]]
 name = "opentelemetry-otlp"
 version = "0.13.0"

@@ -2849,6 +2863,24 @@ dependencies = [
  "thiserror",
 ]
+
+[[package]]
+name = "opentelemetry_sdk"
+version = "0.26.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d2c627d9f4c9cdc1f21a29ee4bfbd6028fcb8bcf2a857b43f3abdf72c9c862f3"
+dependencies = [
+ "async-trait",
+ "futures-channel",
+ "futures-executor",
+ "futures-util",
+ "glob",
+ "once_cell",
+ "opentelemetry 0.26.0",
+ "percent-encoding",
+ "rand",
+ "thiserror",
+]
 
 [[package]]
 name = "option-ext"
 version = "0.2.0"

@@ -4187,12 +4219,14 @@ dependencies = [
 name = "text-generation-backend-llamacpp"
 version = "2.4.1-dev0"
 dependencies = [
+ "async-trait",
  "clap 4.5.20",
  "cmake",
  "cxx",
  "cxx-build",
  "hf-hub",
  "image",
+ "log",
  "metrics",
  "metrics-exporter-prometheus",
  "pkg-config",

@@ -4202,6 +4236,10 @@ dependencies = [
  "tokenizers",
  "tokio",
  "tokio-stream",
+ "tracing",
+ "tracing-opentelemetry 0.27.0",
+ "tracing-subscriber",
+ "utoipa 5.1.2",
 ]
 
 [[package]]

@@ -4330,7 +4368,7 @@ dependencies = [
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
  "ureq",
- "utoipa",
+ "utoipa 4.2.3",
  "utoipa-swagger-ui",
  "uuid",
  "vergen",

@@ -4381,7 +4419,7 @@ dependencies = [
  "tracing",
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
- "utoipa",
+ "utoipa 4.2.3",
  "utoipa-swagger-ui",
 ]
 

@@ -4432,7 +4470,7 @@ dependencies = [
  "tracing",
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
- "utoipa",
+ "utoipa 4.2.3",
  "utoipa-swagger-ui",
 ]
 

@@ -4946,6 +4984,24 @@ dependencies = [
  "web-time 1.1.0",
 ]
+
+[[package]]
+name = "tracing-opentelemetry"
+version = "0.27.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc58af5d3f6c5811462cabb3289aec0093f7338e367e5a33d28c0433b3c7360b"
+dependencies = [
+ "js-sys",
+ "once_cell",
+ "opentelemetry 0.26.0",
+ "opentelemetry_sdk 0.26.0",
+ "smallvec",
+ "tracing",
+ "tracing-core",
+ "tracing-log 0.2.0",
+ "tracing-subscriber",
+ "web-time 1.1.0",
+]
 
 [[package]]
 name = "tracing-opentelemetry-instrumentation-sdk"
 version = "0.16.0"

@@ -5136,7 +5192,19 @@ dependencies = [
  "indexmap 2.6.0",
  "serde",
  "serde_json",
- "utoipa-gen",
+ "utoipa-gen 4.3.0",
+]
+
+[[package]]
+name = "utoipa"
+version = "5.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5e12e84f0ff45b6818029cd0f67280e453c80132c1b9897df407ecc20b9f7cfd"
+dependencies = [
+ "indexmap 2.5.0",
+ "serde",
+ "serde_json",
+ "utoipa-gen 5.1.2",
 ]
 
 [[package]]

@@ -5152,6 +5220,18 @@ dependencies = [
  "syn 2.0.85",
 ]
+
+[[package]]
+name = "utoipa-gen"
+version = "5.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0dfc694d3a3118d2b9e80d68be83bf1aab7988510916934db83da61c14e7e6b2"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "regex",
+ "syn 2.0.79",
+]
 
 [[package]]
 name = "utoipa-swagger-ui"
 version = "6.0.0"

@@ -5164,7 +5244,7 @@ dependencies = [
  "rust-embed",
  "serde",
  "serde_json",
- "utoipa",
+ "utoipa 4.2.3",
  "zip",
 ]
 
backends/llamacpp/CMakeLists.txt:

@@ -11,12 +11,12 @@ set(LLAMA_CPP_TARGET_CUDA_ARCHS "75-real;80-real;86-real;89-real;90-real" CACHE
 option(LLAMA_CPP_BUILD_OFFLINE_RUNNER "Flag to build the standalone c++ backend runner")
 option(LLAMA_CPP_BUILD_CUDA "Flag to build CUDA enabled inference through llama.cpp")
 
-if(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" AND ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
+if (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" AND ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
     message(STATUS "Targeting libc++")
     set(CMAKE_CXX_FLAGS -stdlib=libc++ ${CMAKE_CXX_FLAGS})
-else()
+else ()
     message(STATUS "Not using libc++ ${CMAKE_CXX_COMPILER_ID} ${CMAKE_SYSTEM_NAME}")
-endif()
+endif ()
 
 # Add dependencies
 include(cmake/fmt.cmake)

@@ -42,15 +42,17 @@ fetchcontent_declare(
 
 fetchcontent_makeavailable(llama)
 
-add_library(tgi_llama_cpp_backend_impl STATIC csrc/backend.hpp csrc/backend.cpp)
-target_compile_features(tgi_llama_cpp_backend_impl PRIVATE cxx_std_11)
-target_link_libraries(tgi_llama_cpp_backend_impl PUBLIC fmt::fmt spdlog::spdlog llama common)
+add_library(tgi_llamacpp_backend_impl STATIC csrc/backend.hpp csrc/backend.cpp)
+target_compile_features(tgi_llamacpp_backend_impl PRIVATE cxx_std_11)
+target_link_libraries(tgi_llamacpp_backend_impl PUBLIC fmt::fmt spdlog::spdlog llama common)
 
+install(TARGETS tgi_llamacpp_backend_impl spdlog llama common)
+
 if (${LLAMA_CPP_BUILD_OFFLINE_RUNNER})
     message(STATUS "Building llama.cpp offline runner")
-    add_executable(tgi_llama_cpp_offline_runner offline/main.cpp)
+    add_executable(tgi_llama_cppoffline_runner offline/main.cpp)
 
-    target_link_libraries(tgi_llama_cpp_offline_runner PUBLIC tgi_llama_cpp_backend_impl llama common)
+    target_link_libraries(tgi_llamacpp_offline_runner PUBLIC tgi_llama_cpp_backend_impl llama common)
 endif ()
 
backends/llamacpp/Cargo.toml:

@@ -6,6 +6,7 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
+async-trait = "0.1"
 clap = { version = "4.5.19", features = ["derive"] }
 cxx = "1.0"
 hf-hub = { workspace = true }

@@ -18,6 +19,11 @@ thiserror = "1.0.64"
 tokio = "1.40.0"
 tokio-stream = "0.1.16"
 tokenizers = { workspace = true }
+tracing = "0.1"
+tracing-opentelemetry = "0.27.0"
+tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+utoipa = { version = "5.1.2", features = ["axum_extras"] }
+log = "0.4.22"
 
 [build-dependencies]
 cmake = "0.1"
backends/llamacpp/build.rs:

@@ -1,12 +1,14 @@
 use cxx_build::CFG;
 use std::env;
-use std::path::PathBuf;
+use std::path::{Path, PathBuf};
 
 const CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS: &str = "75-real;80-real;86-real;89-real;90-real";
-const CMAKE_LLAMA_CPP_TARGET: &str = "tgi_llama_cpp_backend_impl";
-const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
+const CMAKE_LLAMA_CPP_TARGET: &str = "tgi_llamacpp_backend_impl";
+const CMAKE_LLAMA_CPP_FFI_TARGET: &str = "tgi_llamacpp_backend";
 const MPI_REQUIRED_VERSION: &str = "4.1";
 
+const BACKEND_DEPS: [&str; 2] = [CMAKE_LLAMA_CPP_TARGET, CMAKE_LLAMA_CPP_FFI_TARGET];
+
 macro_rules! probe {
     ($name: expr, $version: expr) => {
         if let Err(_) = pkg_config::probe_library($name) {

@@ -16,11 +18,12 @@ macro_rules! probe {
     };
 }
 
-fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf {
-    let install_path = env::var("CMAKE_INSTALL_PREFIX")
-        .map(|val| PathBuf::from(val))
-        .unwrap_or(out_dir.join("dist"));
+fn build_backend(
+    is_debug: bool,
+    opt_level: &str,
+    out_dir: &Path,
+    install_path: &PathBuf,
+) -> PathBuf {
     let build_cuda = option_env!("LLAMA_CPP_BUILD_CUDA").unwrap_or("OFF");
     let cuda_archs =
         option_env!("LLAMA_CPP_TARGET_CUDA_ARCHS").unwrap_or(CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS);

@@ -38,41 +41,28 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf
         .define("LLAMA_CPP_TARGET_CUDA_ARCHS", cuda_archs)
         .build();
 
-    // Additional transitive CMake dependencies
-    let deps_folder = out_dir.join("build").join("_deps");
-    for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
-        let dep_name = match is_debug {
-            true => format!("{}d", dependency),
-            false => String::from(dependency),
-        };
-        let dep_path = deps_folder.join(format!("{}-build", dependency));
-        println!("cargo:rustc-link-search={}", dep_path.display());
-        println!("cargo:rustc-link-lib=static={}", dep_name);
-    }
+    let lib_path = install_path.join("lib64");
+    println!("cargo:rustc-link-search=native={}", lib_path.display());
 
     let deps_folder = out_dir.join("build").join("_deps");
     deps_folder
 }
 
-fn build_ffi_layer(deps_folder: &PathBuf) {
-    println!("cargo:warning={}", &deps_folder.display());
+fn build_ffi_layer(deps_folder: &Path, install_prefix: &Path) {
+    println!("cargo:warning={}", deps_folder.display());
     CFG.include_prefix = "backends/llamacpp";
     cxx_build::bridge("src/lib.rs")
         .static_flag(true)
         .std("c++23")
-        .include(deps_folder.join("fmt-src").join("include"))
-        .include(deps_folder.join("spdlog-src").join("include"))
-        .include(deps_folder.join("llama-src").join("common"))
-        .include(deps_folder.join("llama-src").join("ggml").join("include"))
-        .include(deps_folder.join("llama-src").join("include"))
-        .include("csrc/backend.hpp")
-        .file("csrc/ffi.cpp")
-        .compile(CMAKE_LLAMA_CPP_TARGET);
-
-    println!("cargo:rerun-if-changed=CMakeLists.txt");
-    println!("cargo:rerun-if-changed=csrc/backend.hpp");
-    println!("cargo:rerun-if-changed=csrc/backend.cpp");
-    println!("cargo:rerun-if-changed=csrc/ffi.hpp");
+        .include(deps_folder.join("spdlog-src").join("include")) // Why spdlog doesnt install headers?
+        // .include(deps_folder.join("fmt-src").join("include")) // Why spdlog doesnt install headers?
+        // .include(deps_folder.join("llama-src").join("include")) // Why spdlog doesnt install headers?
+        .include(deps_folder.join("llama-src").join("ggml").join("include")) // Why spdlog doesnt install headers?
+        .include(deps_folder.join("llama-src").join("common").join("include")) // Why spdlog doesnt install headers?
+        .include(install_prefix.join("include"))
+        .include("csrc")
+        .file("csrc/ffi.hpp")
+        .compile(CMAKE_LLAMA_CPP_FFI_TARGET);
 }

@@ -84,17 +74,35 @@ fn main() {
         _ => (false, "3"),
     };
 
+    let install_path = env::var("CMAKE_INSTALL_PREFIX")
+        .map(|val| PathBuf::from(val))
+        .unwrap_or(out_dir.join("dist"));
+
     // Build the backend
-    let deps_folder = build_backend(is_debug, opt_level, &out_dir);
+    let deps_path = build_backend(is_debug, opt_level, out_dir.as_path(), &install_path);
 
     // Build the FFI layer calling the backend above
-    build_ffi_layer(&deps_folder);
+    build_ffi_layer(&deps_path, &install_path);
 
     // Emit linkage search path
     probe!("ompi", MPI_REQUIRED_VERSION);
 
     // Backend
-    // BACKEND_DEPS.iter().for_each(|name| {
-    //     println!("cargo:rustc-link-lib=static={}", name);
-    // });
+    BACKEND_DEPS.iter().for_each(|name| {
+        println!("cargo:rustc-link-lib=static={}", name);
+    });
+
+    // Linkage info
+    println!("cargo:rustc-link-search=native={}", out_dir.display());
+    println!("cargo:rustc-link-lib=static=fmtd");
+    println!("cargo:rustc-link-lib=static=spdlogd");
+    println!("cargo:rustc-link-lib=static=common");
+    println!("cargo:rustc-link-lib=dylib=ggml");
+    println!("cargo:rustc-link-lib=dylib=llama");
+
+    // Rerun if one of these file change
+    println!("cargo:rerun-if-changed=CMakeLists.txt");
+    println!("cargo:rerun-if-changed=csrc/backend.hpp");
+    println!("cargo:rerun-if-changed=csrc/backend.cpp");
+    println!("cargo:rerun-if-changed=csrc/ffi.hpp");
 }
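Editor's note: the overall flow this build.rs adopts (the cmake crate builds and installs the C++ static library, cxx_build compiles the bridge, then cargo link directives wire everything into rustc) can be sketched as below. This is a trimmed, hypothetical build script, not the file from the commit; the "dist" install prefix, the lib64 directory, and the library names are assumptions mirrored from the diff, and it presumes build-dependencies on the cmake and cxx-build crates.

// Hypothetical build.rs sketch of the build-and-link flow (assumptions noted above).
use std::env;
use std::path::PathBuf;

fn main() {
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    let install_path = out_dir.join("dist"); // assumed install prefix under OUT_DIR

    // 1. Configure and build the CMake project; install(TARGETS ...) in
    //    CMakeLists.txt places the static libraries under the install prefix.
    let _dst = cmake::Config::new(".")
        .define("CMAKE_INSTALL_PREFIX", &install_path)
        .build();

    // 2. Compile the cxx FFI bridge against the installed headers.
    cxx_build::bridge("src/lib.rs")
        .include(install_path.join("include"))
        .std("c++23")
        .compile("tgi_llamacpp_backend"); // FFI target name, mirrors the diff

    // 3. Tell rustc where the native artifacts live and how to link them.
    println!(
        "cargo:rustc-link-search=native={}",
        install_path.join("lib64").display()
    );
    println!("cargo:rustc-link-lib=static=tgi_llamacpp_backend_impl");
    println!("cargo:rustc-link-lib=dylib=llama");
    println!("cargo:rerun-if-changed=CMakeLists.txt");
}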
backends/llamacpp/csrc/backend.cpp:

@@ -14,33 +14,35 @@
 
 #include "backend.hpp"
 
-namespace huggingface::tgi::backends::llama {
-    std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
-    CreateLlamaCppBackend(const std::filesystem::path& modelPath) {
+namespace huggingface::tgi::backends::llamacpp {
+    [[nodiscard]]
+    std::expected<std::pair<llama_model *, llama_context *>, TgiLlamaCppBackendError>
+    TgiLlamaCppBackend::FromGGUF(const std::filesystem::path &modelPath) noexcept {
         SPDLOG_DEBUG(FMT_STRING("Loading model from {}"), modelPath);
 
         llama_backend_init();
         llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
 
         // Load the model
-        if(!exists(modelPath)) {
+        if (!exists(modelPath)) {
             return std::unexpected(TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST);
         }
 
         auto params = llama_model_default_params();
-        auto* model = llama_load_model_from_file(modelPath.c_str(), params);
-        auto* context = llama_new_context_with_model(model, {
+        auto *model = llama_load_model_from_file(modelPath.c_str(), params);
+        auto *context = llama_new_context_with_model(model, {
                 .n_batch = 1,
                 .n_threads = 16,
                 .attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL,
                 .flash_attn = false,
         });
 
-        return std::make_unique<huggingface::tgi::backends::llama::TgiLlamaCppBackend>(model, context);
+        return std::make_pair(model, context);
     }
 
-    huggingface::tgi::backends::llama::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model, llama_context *const ctx)
-            : model(model), ctx(ctx) {
+    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model,
+                                                                                 llama_context *const ctx)
+            : model(model), ctx(ctx) {
 #ifndef NDEBUG
         char modelName[256];
         llama_model_meta_val_str(llama_get_model(ctx), "general.name", modelName, sizeof(modelName));

@@ -48,13 +50,13 @@ namespace huggingface::tgi::backends::llama {
 #endif
     }
 
-    huggingface::tgi::backends::llama::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
+    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
         if (ctx) {
             SPDLOG_DEBUG("Freeing llama.cpp context");
             llama_free(ctx);
         }
 
-        if(model) {
+        if (model) {
             SPDLOG_DEBUG("Freeing llama.cpp model");
             llama_free_model(model);
         }

@@ -63,7 +65,8 @@ namespace huggingface::tgi::backends::llama {
     std::vector<TgiLlamaCppBackend::TokenId> TgiLlamaCppBackend::Tokenize(const std::string &text) const {
         std::vector<TgiLlamaCppBackend::TokenId> tokens(llama_n_seq_max(ctx));
 
-        if(auto nTokens = llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true, true); nTokens < 0){
+        if (auto nTokens = llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true,
+                                          true); nTokens < 0) {
             tokens.resize(-nTokens);
             llama_tokenize(model, text.c_str(), text.length(), tokens.data(), tokens.capacity(), true, true);
         } else {

@@ -75,14 +78,15 @@ namespace huggingface::tgi::backends::llama {
     }
 
     std::unique_ptr<llama_sampler *> TgiLlamaCppBackend::GetSamplerFromArgs(
-            const uint32_t topK, const float_t topP, const float_t frequencyPenalty, const float_t repetitionPenalty, const uint64_t seed) {
+            const uint32_t topK, const float_t topP, const float_t frequencyPenalty, const float_t repetitionPenalty,
+            const uint64_t seed) {
         auto *sampler = llama_sampler_chain_init({.no_perf = false});
 
         // Penalties
         llama_sampler_chain_add(sampler, llama_sampler_init_penalties(
                 llama_n_vocab(model),
                 llama_token_eos(model),
-                llama_token_nl (model),
+                llama_token_nl(model),
                 0.0f,
                 repetitionPenalty,
                 frequencyPenalty,

@@ -92,15 +96,16 @@ namespace huggingface::tgi::backends::llama {
         ));
         llama_sampler_chain_add(sampler, llama_sampler_init_top_k(static_cast<int32_t>(topK)));
 
-        if(0 < topP && topP < 1) {
+        if (0 < topP && topP < 1) {
             llama_sampler_chain_add(sampler, llama_sampler_init_top_p(topP, 1));
         }
 
         llama_sampler_chain_add(sampler, llama_sampler_init_dist(seed));
-        return std::make_unique<llama_sampler*>(sampler);
+        return std::make_unique<llama_sampler *>(sampler);
     }
 
-    std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
+    std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError>
+    huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::Generate(
             std::span<const TokenId> tokens,
             const uint32_t topK,
             const float_t topP,

@@ -108,7 +113,7 @@ namespace huggingface::tgi::backends::llama {
             const float_t repetitionPenalty,
             const uint32_t maxNewTokens,
             const uint64_t seed
     ) {
         SPDLOG_DEBUG(FMT_STRING("Received {:d} tokens to schedule"), tokens.size());
 
         // Allocate generation result

@@ -120,7 +125,7 @@ namespace huggingface::tgi::backends::llama {
         auto sampler = GetSamplerFromArgs(topK, topP, frequencyPenalty, repetitionPenalty, seed);
 
         // Decode
-        for(auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
+        for (auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
 #ifndef NDEBUG
             const auto start = std::chrono::steady_clock::now();
             const auto status = llama_decode(ctx, batch);
backends/llamacpp/csrc/backend.hpp:

@@ -9,12 +9,14 @@
 #include <filesystem>
 #include <memory>
 #include <span>
+#include <vector>
 
 #include <llama.h>
 
 #define LLAMA_SUCCESS(x) x == 0
 
-namespace huggingface::tgi::backends::llama {
-    enum TgiLlamaCppBackendError: uint8_t {
+namespace huggingface::tgi::backends::llamacpp {
+    enum TgiLlamaCppBackendError : uint8_t {
         MODEL_FILE_DOESNT_EXIST = 1
     };
 

@@ -22,8 +24,8 @@ namespace huggingface::tgi::backends::llama {
         using TokenId = llama_token;
 
     private:
-        llama_model* model;
-        llama_context* ctx;
+        llama_model *model;
+        llama_context *ctx;
 
         /**
          *

@@ -35,7 +37,15 @@ namespace huggingface::tgi::backends::llama {
                 uint32_t topK, float_t topP, float_t frequencyPenalty, float_t repetitionPenalty, uint64_t seed);
 
     public:
+        /**
+         *
+         * @return
+         */
+        static std::expected<std::pair<llama_model *, llama_context *>, TgiLlamaCppBackendError>
+        FromGGUF(const std::filesystem::path &) noexcept;
+
         TgiLlamaCppBackend(llama_model *model, llama_context *ctx);
 
         ~TgiLlamaCppBackend();
 
         /**

@@ -44,7 +54,7 @@ namespace huggingface::tgi::backends::llama {
          * @return
          */
         [[nodiscard("Tokens will be freed after this call if not assigned to an lvalue")]]
-        std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string& text) const;
+        std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string &text) const;
 
         /**
          *

@@ -71,7 +81,7 @@ namespace huggingface::tgi::backends::llama {
 
     [[nodiscard("Create backend will be freed after this call if not assigned to an lvalue")]]
     std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
-    CreateLlamaCppBackend(const std::filesystem::path& root);
+    CreateLlamaCppBackend(const std::filesystem::path &root);
 }
 
 #endif //TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
backends/llamacpp/csrc/ffi.hpp:

@@ -5,14 +5,44 @@
 #ifndef TGI_LLAMA_CPP_BACKEND_FFI_HPP
 #define TGI_LLAMA_CPP_BACKEND_FFI_HPP
 
+#include <exception>
+#include <filesystem>
+#include <string_view>
+
+#include <spdlog/spdlog.h>
 #include "backend.hpp"
-//#include "backends/llamacpp/src/lib.rs.h"
 
-namespace huggingface::tgi::backends::llama {
-    class LlamaCppBackendImpl {
+namespace huggingface::tgi::backends::llamacpp::impl {
+    class LlamaCppBackendImpl;
+}
+
+
+#include "backends/llamacpp/src/lib.rs.h"
+
+
+namespace huggingface::tgi::backends::llamacpp::impl {
+
+    class LlamaCppBackendException : std::exception {
 
     };
+
+    class LlamaCppBackendImpl {
+    private:
+        TgiLlamaCppBackend _inner;
+
+    public:
+        LlamaCppBackendImpl(llama_model *model, llama_context *context) : _inner(model, context) {}
+    };
+
+    std::unique_ptr<LlamaCppBackendImpl> CreateLlamaCppBackendImpl(rust::Str modelPath) {
+        const auto cxxPath = std::string_view(modelPath);
+        if (auto maybe = TgiLlamaCppBackend::FromGGUF(std::filesystem::path(cxxPath)); maybe.has_value()) {
+            auto [model, context] = *maybe;
+            return std::make_unique<LlamaCppBackendImpl>(model, context);
+        } else {
+            throw LlamaCppBackendException();
+        }
+    }
 }
 
backends/llamacpp/offline/main.cpp:

@@ -10,6 +10,8 @@
 #include <spdlog/spdlog.h>
 #include "../csrc/backend.hpp"
 
+using namespace huggingface::tgi::backends::llamacpp;
+
 int main(int argc, char** argv) {
     if (argc < 2) {
         fmt::print("No model folder provider");

@@ -21,7 +23,7 @@ int main(int argc, char** argv) {
     const auto prompt = "My name is Morgan";
 
     const auto modelPath = absolute(std::filesystem::path(argv[1]));
-    if (auto maybeBackend = huggingface::tgi::backends::llama::CreateLlamaCppBackend(modelPath); maybeBackend.has_value()) {
+    if (auto maybeBackend = CreateLlamaCppBackend(modelPath); maybeBackend.has_value()) {
         // Retrieve the backend
         const auto& backend = *maybeBackend;
 

@@ -38,7 +40,7 @@ int main(int argc, char** argv) {
 
     } else {
         switch (maybeBackend.error()) {
-            case huggingface::tgi::backends::llama::TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST:
+            case TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST:
                 fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Specified file {} doesnt exist", modelPath);
                 return maybeBackend.error();
         }
backends/llamacpp/src/backend.rs:

@@ -1,31 +1,66 @@
 use crate::ffi::{create_llamacpp_backend, LlamaCppBackendImpl};
+use async_trait::async_trait;
 use cxx::UniquePtr;
-use std::path::Path;
+use std::path::{Path, PathBuf};
+use std::sync::Arc;
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
+use thiserror::Error;
+use tokio::task::spawn_blocking;
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::info;
 
-pub struct TgiLlamaCppBakend {
-    backend: UniquePtr<LlamaCppBackendImpl>,
+unsafe impl Send for LlamaCppBackendImpl {}
+
+#[derive(Debug, Error)]
+pub enum LlamaCppBackendError {
+    #[error("Provided GGUF model path {0} doesn't exist")]
+    ModelFileDoesntExist(String),
+
+    #[error("Failed to initialize model from GGUF file {0}: {1}")]
+    ModelInitializationFailed(PathBuf, String),
 }
 
-impl TgiLlamaCppBakend {
-    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self, ()> {
-        Ok(Self {
-            backend: create_llamacpp_backend(model_path.as_ref().to_str().unwrap()),
-        })
+pub struct LlamaCppBackend {}
+
+impl LlamaCppBackend {
+    pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
+        let path = Arc::new(model_path.as_ref());
+        if !path.exists() {
+            return Err(LlamaCppBackendError::ModelFileDoesntExist(
+                path.display().to_string(),
+            ));
+        }
+
+        let mut backend = create_llamacpp_backend(path.to_str().unwrap()).map_err(|err| {
+            LlamaCppBackendError::ModelInitializationFailed(
+                path.to_path_buf(),
+                err.what().to_string(),
+            )
+        })?;
+
+        info!(
+            "Successfully initialized llama.cpp backend from {}",
+            path.display()
+        );
+
+        spawn_blocking(move || scheduler_loop(backend));
+        Ok(Self {})
     }
 }
 
-impl Backend for TgiLlamaCppBakend {
+async fn scheduler_loop(mut backend: UniquePtr<LlamaCppBackendImpl>) {}
+
+#[async_trait]
+impl Backend for LlamaCppBackend {
     fn schedule(
         &self,
-        request: ValidGenerateRequest,
+        _request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
         Err(InferError::GenerationError("Not implemented yet".into()))
     }
 
-    async fn health(&self, current_health: bool) -> bool {
-        todo!()
+    async fn health(&self, _: bool) -> bool {
+        true
     }
 }
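Editor's note: the `unsafe impl Send for LlamaCppBackendImpl {}` above exists because an opaque C++ type held through a cxx UniquePtr is not Send by default, yet the constructor moves the backend into tokio's blocking pool. The following is a minimal, self-contained illustration of that pattern with a stand-in handle type; it is not the crate's real FFI type, just a sketch assuming a tokio dependency with the rt and macros features.

// Sketch of the Send-assertion + spawn_blocking pattern (stand-in types).
use tokio::task::spawn_blocking;

struct ForeignHandle(*mut ());          // raw pointers make this !Send by default
unsafe impl Send for ForeignHandle {}   // caller-asserted: safe to move across threads

async fn start(handle: ForeignHandle) {
    // Ownership of the handle moves into the closure, which runs on the blocking
    // pool, keeping long-running native calls off the async executor threads.
    let _join = spawn_blocking(move || {
        let _backend = handle; // a real scheduler loop would drive the backend here
    });
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    start(ForeignHandle(std::ptr::null_mut())).await;
}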
backends/llamacpp/src/lib.rs:

@@ -1,6 +1,6 @@
 pub mod backend;
 
-#[cxx::bridge(namespace = "huggingface::tgi::backends::llama")]
+#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp::impl")]
 mod ffi {
     unsafe extern "C++" {
         include!("backends/llamacpp/csrc/ffi.hpp");

@@ -9,8 +9,8 @@ mod ffi {
         type LlamaCppBackendImpl;
 
         #[rust_name = "create_llamacpp_backend"]
-        fn CreateLlamaCppBackend(
-            engine_folder: &str,
-        ) -> UniquePtr<LlamaCppBackendImpl>;
+        fn CreateLlamaCppBackendImpl(
+            modelPath: &str,
+        ) -> Result<UniquePtr<LlamaCppBackendImpl>>;
     }
 }
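Editor's note: the key change in this bridge is the return type, now Result<UniquePtr<LlamaCppBackendImpl>>, which lets cxx catch the C++ exception thrown by ffi.hpp and surface it as Err(cxx::Exception) on the Rust side. The sketch below shows the general shape of that mechanism with placeholder names; it is illustrative only and would need a matching C++ header and definition to actually compile.

// Illustrative cxx bridge with a throwing C++ factory (placeholder names).
#[cxx::bridge]
mod demo_ffi {
    unsafe extern "C++" {
        include!("demo/include/factory.hpp"); // hypothetical header

        type NativeBackend;

        // Declaring Result<...> makes cxx translate a thrown C++ exception
        // into Err(cxx::Exception) instead of aborting the process.
        fn create_native_backend(model_path: &str) -> Result<UniquePtr<NativeBackend>>;
    }
}

fn load(model_path: &str) {
    match demo_ffi::create_native_backend(model_path) {
        Ok(_backend) => println!("backend ready"),
        // cxx::Exception exposes the C++ what() message, as used by the diff's map_err.
        Err(err) => eprintln!("failed to initialize backend: {}", err.what()),
    }
}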
backends/llamacpp/src/main.rs:

@@ -1,7 +1,8 @@
 use clap::{Parser, Subcommand};
+use std::path::PathBuf;
+use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};
 use text_generation_router::{server, usage_stats};
 use thiserror::Error;
-use text_generation_router::server::ApiDoc;
 
 /// App Configuration
 #[derive(Parser, Debug)]

@@ -38,6 +39,8 @@ struct Args {
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,
+    #[clap(long, env, help = "Path to GGUF model file(s) to load")]
+    gguf_path: PathBuf,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
     #[clap(long, env)]

@@ -98,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
         hostname,
         port,
         master_shard_uds_path,
+        gguf_path,
         tokenizer_name,
         tokenizer_config_path,
         revision,

@@ -116,13 +120,13 @@ async fn main() -> Result<(), RouterError> {
         usage_stats,
     } = args;
 
-    if let Some(Commands::PrintSchema) = command {
-        use utoipa::OpenApi;
-        let api_doc = ApiDoc::openapi();
-        let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
-        println!("{}", api_doc);
-        std::process::exit(0);
-    };
+    // if let Some(Commands::PrintSchema) = command {
+    //     use utoipa::OpenApi;
+    //     let api_doc = ApiDoc::openapi();
+    //     let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
+    //     println!("{}", api_doc);
+    //     std::process::exit(0);
+    // };
     text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
 
     // Validate args

@@ -158,7 +162,7 @@ async fn main() -> Result<(), RouterError> {
         }
     }
 
-    let backend = LlamaCppBackend::new();
+    let backend = LlamaCppBackend::new(gguf_path)?;
 
     // Run server
     server::run(

@@ -185,7 +189,7 @@ async fn main() -> Result<(), RouterError> {
         max_client_batch_size,
         usage_stats,
     )
     .await?;
     Ok(())
 }

@@ -194,9 +198,9 @@ enum RouterError {
     #[error("Argument validation error: {0}")]
     ArgumentValidation(String),
     #[error("Backend failed: {0}")]
-    Backend(#[from] V3Error),
+    Backend(#[from] LlamaCppBackendError),
     #[error("WebServer error: {0}")]
     WebServer(#[from] server::WebServerError),
     #[error("Tokio runtime failed to start: {0}")]
     Tokio(#[from] std::io::Error),
 }
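Editor's note: switching RouterError::Backend to `#[from] LlamaCppBackendError` is what allows the bare `LlamaCppBackend::new(gguf_path)?` in main to propagate backend failures. A minimal, self-contained illustration of that thiserror pattern (with simplified stand-in enums, not the real ones) follows; it only assumes a thiserror dependency.

// Self-contained illustration of the #[from] conversion the diff relies on.
use thiserror::Error;

#[derive(Debug, Error)]
enum BackendError {
    #[error("Provided GGUF model path {0} doesn't exist")]
    ModelFileDoesntExist(String),
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Backend failed: {0}")]
    Backend(#[from] BackendError),
}

fn new_backend(path: &str) -> Result<(), BackendError> {
    Err(BackendError::ModelFileDoesntExist(path.to_string()))
}

fn main() -> Result<(), RouterError> {
    // `?` converts BackendError into RouterError::Backend via the generated From impl.
    new_backend("/models/example.gguf")?;
    Ok(())
}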