Mirror of https://github.com/huggingface/text-generation-inference.git

commit 52d57dca79 (parent 7d1f8a2bd6)

    feat(llamacpp): initial end2end build

backends/llamacpp/CMakeLists.txt

@@ -6,12 +6,18 @@ set(CMAKE_CXX_STANDARD 20)
include(FetchContent)

set(LLAMA_CPP_TARGET_VERSION "b3837" STRING "Version of llama.cpp to build against")

option(LLAMA_CPP_BUILD_OFFLINE_RUNNER "Flag to build the standalone c++ backend runner")
option(LLAMA_CPP_BUILD_CUDA "Flag to build CUDA enabled inference through llama.cpp")

# Add dependencies
include(cmake/fmt.cmake)
include(cmake/spdlog.cmake)

if(${LLAMA_CPP_BUILD_CUDA})
    message(STATUS "Enabling llama.cpp CUDA support")
    set(GGML_CUDA ON)
endif()

# Download llama.cpp repo at the specific version
fetchcontent_declare(
    llama

@@ -25,4 +31,12 @@ fetchcontent_makeavailable(llama)

add_library(tgi_llama_cpp_backend_impl STATIC csrc/backend.hpp csrc/backend.cpp)
target_compile_features(tgi_llama_cpp_backend_impl PRIVATE cxx_std_11)
-target_link_libraries(tgi_llama_cpp_backend_impl fmt::fmt spdlog::spdlog llama common)
+target_link_libraries(tgi_llama_cpp_backend_impl PUBLIC fmt::fmt spdlog::spdlog llama common)

if(${LLAMA_CPP_BUILD_OFFLINE_RUNNER})
    message(STATUS "Building llama.cpp offline runner")
    add_executable(tgi_llama_cpp_offline_runner offline/main.cpp)
    target_link_libraries(tgi_llama_cpp_offline_runner tgi_llama_cpp_backend_impl)
endif()

backends/llamacpp/Cargo.toml

@@ -6,3 +6,20 @@ authors.workspace = true
homepage.workspace = true

[dependencies]
clap = { version = "4.5.19", features = ["derive"] }
cxx = "1.0"
hf-hub = { workspace = true }
image = { version = "0.25.1", features = ["default-formats"] }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
serde_json = "1.0.128"
text-generation-router = { path = "../../router" }
thiserror = "1.0.64"
tokio = "1.40.0"
tokio-stream = "0.1.16"
tokenizers = { workspace = true }

[build-dependencies]
cmake = "0.1"
cxx-build = { version = "1.0", features = ["parallel"] }
pkg-config = "0.3"

backends/llamacpp/build.rs (new file, 94 lines)

use cxx_build::CFG;
use std::env;
use std::path::PathBuf;

const CMAKE_LLAMA_CPP_TARGET: &str = "tgi_llama_cpp_backend_impl";
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
const MPI_REQUIRED_VERSION: &str = "4.1";

macro_rules! probe {
    ($name: expr, $version: expr) => {
        if let Err(_) = pkg_config::probe_library($name) {
            pkg_config::probe_library(&format!("{}-{}", $name, $version))
                .expect(&format!("Failed to locate {}", $name));
        }
    };
}

fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf {
    let install_path = env::var("CMAKE_INSTALL_PREFIX")
        .map(|val| PathBuf::from(val))
        .unwrap_or(out_dir.join("dist"));

    let _ = cmake::Config::new(".")
        .uses_cxx11()
        .generator("Ninja")
        .profile(match is_debug {
            true => "Debug",
            false => "Release",
        })
        .env("OPT_LEVEL", opt_level)
        .define("CMAKE_INSTALL_PREFIX", &install_path)
        // .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
        // .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
        // .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
        .build();

    // Additional transitive CMake dependencies
    let deps_folder = out_dir.join("build").join("_deps");
    for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
        let dep_name = match is_debug {
            true => format!("{}d", dependency),
            false => String::from(dependency),
        };
        let dep_path = deps_folder.join(format!("{}-build", dependency));
        println!("cargo:rustc-link-search={}", dep_path.display());
        println!("cargo:rustc-link-lib=static={}", dep_name);
    }

    let deps_folder = out_dir.join("build").join("_deps");
    deps_folder
}

fn build_ffi_layer(deps_folder: &PathBuf) {
    println!("cargo:warning={}", &deps_folder.display());
    CFG.include_prefix = "backends/llamacpp";
    cxx_build::bridge("src/lib.rs")
        .static_flag(true)
        .include(deps_folder.join("fmt-src").join("include"))
        .include(deps_folder.join("spdlog-src").join("include"))
        .include(deps_folder.join("llama-src").join("common"))
        .include(deps_folder.join("llama-src").join("ggml").join("include"))
        .include(deps_folder.join("llama-src").join("include"))
        .file("csrc/backend.cpp")
        .std("c++20")
        .compile(CMAKE_LLAMA_CPP_TARGET);

    println!("cargo:rerun-if-changed=CMakeLists.txt");
    println!("cargo:rerun-if-changed=csrc/backend.hpp");
    println!("cargo:rerun-if-changed=csrc/backend.cpp");
}

fn main() {
    // Misc variables
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    let build_profile = env::var("PROFILE").unwrap();
    let (is_debug, opt_level) = match build_profile.as_ref() {
        "debug" => (true, "0"),
        _ => (false, "3"),
    };

    // Build the backend
    let deps_folder = build_backend(is_debug, opt_level, &out_dir);

    // Build the FFI layer calling the backend above
    build_ffi_layer(&deps_folder);

    // Emit linkage search path
    probe!("ompi", MPI_REQUIRED_VERSION);

    // Backend
    // BACKEND_DEPS.iter().for_each(|name| {
    //     println!("cargo:rustc-link-lib=static={}", name);
    // });
}
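
A note on the build script above: the probe! macro first tries the bare pkg-config name and only then falls back to a versioned one. The following is a minimal sketch of what the probe!("ompi", MPI_REQUIRED_VERSION) call amounts to; it assumes Open MPI publishes pkg-config metadata under either "ompi" or "ompi-4.1", and it is an illustration rather than part of the commit.

// Sketch of the expanded probe! call, written as a standalone build script.
fn main() {
    // First try the unversioned pkg-config entry for Open MPI.
    if pkg_config::probe_library("ompi").is_err() {
        // Fall back to the versioned entry, failing the build if neither exists.
        pkg_config::probe_library("ompi-4.1").expect("Failed to locate ompi");
    }
}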

backends/llamacpp/cmake/spdlog.cmake

@@ -4,9 +4,10 @@ set(SPDLOG_FMT_EXTERNAL ON)

# Define the level at which SPDLOG_ compilation level is defined
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
+    message(STATUS "Verbose logging is enabled in debug build")
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_DEBUG)
else()
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_INFO)
endif ()

fetchcontent_declare(

backends/llamacpp/csrc/backend.cpp

@@ -46,8 +46,11 @@ namespace huggingface::tgi::backends::llama {
    }

    TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model, llama_context *const ctx)
-        : model(model), ctx(ctx), batch() {
+        : model(model), ctx(ctx), batch()
+    {
+        char modelName[128];
+        llama_model_meta_val_str(model, "general.name", modelName, sizeof(modelName));
+        SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
    }

    TgiLlamaCppBackend::~TgiLlamaCppBackend() {

@@ -63,4 +66,8 @@ namespace huggingface::tgi::backends::llama {
            llama_free(ctx);
        }
    }
+
+    void TgiLlamaCppBackend::schedule() {
+        std::vector<llama_token> tokens;
+    }
}

backends/llamacpp/csrc/backend.hpp

@@ -1,7 +1,6 @@
//
// Created by Morgan Funtowicz on 9/28/2024.
//

#ifndef TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
#define TGI_LLAMA_CPP_BACKEND_BACKEND_HPP

@@ -9,7 +8,7 @@
#include <llama.h>

namespace huggingface::tgi::backends::llama {
-    const char* TGI_BACKEND_LLAMA_CPP_NAME = "llama.cpp";
+    // const char* TGI_BACKEND_LLAMA_CPP_NAME = "llama.cpp";

    class TgiLlamaCppBackend {

@@ -18,8 +17,10 @@ namespace huggingface::tgi::backends::llama {
        llama_context* ctx;
        llama_batch batch;
    public:
-        TgiLlamaCppBackend(llama_model* const model, llama_context* const);
+        TgiLlamaCppBackend(llama_model *model, llama_context *ctx);
        ~TgiLlamaCppBackend();

+        void schedule();
    };

    std::unique_ptr<TgiLlamaCppBackend> CreateLlamaCppBackend(std::string_view root);

backends/llamacpp/offline/main.cpp (new file, 22 lines)

//
// Created by mfuntowicz on 10/3/24.
//

#include <string_view>
#include <fmt/format.h>
#include <fmt/color.h>
#include <spdlog/spdlog.h>
#include "../csrc/backend.hpp"

int main(int argc, char** argv) {
    if(argc < 2) {
        fmt::print("No model folder provider");
        return 1;
    }

    spdlog::set_level(spdlog::level::debug);

    const std::string_view model_root = argv[1];
    auto backend = huggingface::tgi::backends::llama::CreateLlamaCppBackend(model_root);
    fmt::print(fmt::emphasis::bold | fg(fmt::color::yellow), "Successfully initialized llama.cpp model from {}\n", model_root);
}

backends/llamacpp/src/backend.rs (new file, 18 lines)

use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use tokio_stream::wrappers::UnboundedReceiverStream;

pub struct TgiLlamaCppBakend {}

impl Backend for TgiLlamaCppBakend {
    fn schedule(
        &self,
        request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        Err(InferError::GenerationError("Not implemented yet".into()))
    }

    async fn health(&self, current_health: bool) -> bool {
        todo!()
    }
}
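
The schedule() signature above hands the router an UnboundedReceiverStream of generation events. Below is a minimal sketch of how such a stream is typically produced with the tokio and tokio-stream crates already listed in Cargo.toml; the String payload stands in for InferStreamResponse, and make_stream is a hypothetical helper, not part of the commit.

use tokio_stream::wrappers::UnboundedReceiverStream;

fn make_stream() -> UnboundedReceiverStream<String> {
    // The sender half would live in a generation task that pushes tokens as they
    // are produced; the receiver half is wrapped into a Stream for the router.
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
    tx.send("token".to_string()).unwrap();
    UnboundedReceiverStream::new(rx)
}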

backends/llamacpp/src/lib.rs (new file, 11 lines)

pub mod backend;

#[cxx::bridge(namespace = "huggingface::tgi::backends::llama")]
mod ffi {
    unsafe extern "C++" {
        include!("backends/llamacpp/csrc/backend.cpp");

        /// Represent an instance of the llama.cpp backend instance on C++ side
        type LlamaCppBackendImpl;
    }
}

backends/llamacpp/src/main.rs

@@ -1,3 +1,202 @@
-fn main() {
-    println!("Hello, world!");
use clap::{Parser, Subcommand};
use text_generation_router::{server, usage_stats};
use thiserror::Error;
use text_generation_router::server::ApiDoc;

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[command(subcommand)]
    command: Option<Commands>,

    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    api_key: Option<String>,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
    #[clap(default_value = "on", long, env)]
    usage_stats: usage_stats::UsageStatsLevel,
}

#[derive(Debug, Subcommand)]
enum Commands {
    PrintSchema,
}

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        command,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        validation_workers,
        api_key,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        usage_stats,
    } = args;

    if let Some(Commands::PrintSchema) = command {
        use utoipa::OpenApi;
        let api_doc = ApiDoc::openapi();
        let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
        println!("{}", api_doc);
        std::process::exit(0);
    };
    text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }

    if let Some(max_batch_size) = max_batch_size {
        if max_batch_size == 0 {
            return Err(RouterError::ArgumentValidation(
                "`max_batch_size` must be > 0".to_string(),
            ));
        }
    }

    let backend = LlamaCppBackend::new();

    // Run server
    server::run(
        backend,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        validation_workers,
        api_key,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        hostname,
        port,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        usage_stats,
    )
    .await?;
    Ok(())
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("Backend failed: {0}")]
    Backend(#[from] V3Error),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}

backends/trtllm/CMakeLists.txt

@@ -18,6 +18,8 @@ set(CMAKE_CXX_STANDARD 20)
include(FetchContent)
include(ExternalProject)

+set(CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS} "--allow-unsupported-compiler -ccbin=gcc")

option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")