mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-22 23:42:06 +00:00

commit 52d57dca79 (parent 7d1f8a2bd6)
feat(llamacpp): initial end2end build
backends/llamacpp/CMakeLists.txt

@@ -6,12 +6,18 @@ set(CMAKE_CXX_STANDARD 20)
 include(FetchContent)
 
 set(LLAMA_CPP_TARGET_VERSION "b3837" STRING "Version of llama.cpp to build against")
+option(LLAMA_CPP_BUILD_OFFLINE_RUNNER "Flag to build the standalone c++ backend runner")
+option(LLAMA_CPP_BUILD_CUDA "Flag to build CUDA enabled inference through llama.cpp")
 
 # Add dependencies
 include(cmake/fmt.cmake)
 include(cmake/spdlog.cmake)
 
+if(${LLAMA_CPP_BUILD_CUDA})
+    message(STATUS "Enabling llama.cpp CUDA support")
+    set(GGML_CUDA ON)
+endif()
 
 # Download llama.cpp repo at the specific version
 fetchcontent_declare(
         llama

@@ -25,4 +31,12 @@ fetchcontent_makeavailable(llama)
 
 add_library(tgi_llama_cpp_backend_impl STATIC csrc/backend.hpp csrc/backend.cpp)
 target_compile_features(tgi_llama_cpp_backend_impl PRIVATE cxx_std_11)
-target_link_libraries(tgi_llama_cpp_backend_impl fmt::fmt spdlog::spdlog llama common)
+target_link_libraries(tgi_llama_cpp_backend_impl PUBLIC fmt::fmt spdlog::spdlog llama common)
+
+if(${LLAMA_CPP_BUILD_OFFLINE_RUNNER})
+    message(STATUS "Building llama.cpp offline runner")
+    add_executable(tgi_llama_cpp_offline_runner offline/main.cpp)
+    target_link_libraries(tgi_llama_cpp_offline_runner tgi_llama_cpp_backend_impl)
+endif()
backends/llamacpp/Cargo.toml

@@ -6,3 +6,20 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
+clap = { version = "4.5.19", features = ["derive"] }
+cxx = "1.0"
+hf-hub = { workspace = true }
+image = { version = "0.25.1", features = ["default-formats"] }
+metrics = { workspace = true }
+metrics-exporter-prometheus = { workspace = true }
+serde_json = "1.0.128"
+text-generation-router = { path = "../../router" }
+thiserror = "1.0.64"
+tokio = "1.40.0"
+tokio-stream = "0.1.16"
+tokenizers = { workspace = true }
+
+[build-dependencies]
+cmake = "0.1"
+cxx-build = { version = "1.0", features = ["parallel"] }
+pkg-config = "0.3"
backends/llamacpp/build.rs (new file, 94 lines)

use cxx_build::CFG;
use std::env;
use std::path::PathBuf;

const CMAKE_LLAMA_CPP_TARGET: &str = "tgi_llama_cpp_backend_impl";
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
const MPI_REQUIRED_VERSION: &str = "4.1";

macro_rules! probe {
    ($name: expr, $version: expr) => {
        if let Err(_) = pkg_config::probe_library($name) {
            pkg_config::probe_library(&format!("{}-{}", $name, $version))
                .expect(&format!("Failed to locate {}", $name));
        }
    };
}

fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf {
    let install_path = env::var("CMAKE_INSTALL_PREFIX")
        .map(|val| PathBuf::from(val))
        .unwrap_or(out_dir.join("dist"));

    let _ = cmake::Config::new(".")
        .uses_cxx11()
        .generator("Ninja")
        .profile(match is_debug {
            true => "Debug",
            false => "Release",
        })
        .env("OPT_LEVEL", opt_level)
        .define("CMAKE_INSTALL_PREFIX", &install_path)
        // .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
        // .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
        // .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
        .build();

    // Additional transitive CMake dependencies
    let deps_folder = out_dir.join("build").join("_deps");
    for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
        let dep_name = match is_debug {
            true => format!("{}d", dependency),
            false => String::from(dependency),
        };
        let dep_path = deps_folder.join(format!("{}-build", dependency));
        println!("cargo:rustc-link-search={}", dep_path.display());
        println!("cargo:rustc-link-lib=static={}", dep_name);
    }

    let deps_folder = out_dir.join("build").join("_deps");
    deps_folder
}

fn build_ffi_layer(deps_folder: &PathBuf) {
    println!("cargo:warning={}", &deps_folder.display());
    CFG.include_prefix = "backends/llamacpp";
    cxx_build::bridge("src/lib.rs")
        .static_flag(true)
        .include(deps_folder.join("fmt-src").join("include"))
        .include(deps_folder.join("spdlog-src").join("include"))
        .include(deps_folder.join("llama-src").join("common"))
        .include(deps_folder.join("llama-src").join("ggml").join("include"))
        .include(deps_folder.join("llama-src").join("include"))
        .file("csrc/backend.cpp")
        .std("c++20")
        .compile(CMAKE_LLAMA_CPP_TARGET);

    println!("cargo:rerun-if-changed=CMakeLists.txt");
    println!("cargo:rerun-if-changed=csrc/backend.hpp");
    println!("cargo:rerun-if-changed=csrc/backend.cpp");
}

fn main() {
    // Misc variables
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    let build_profile = env::var("PROFILE").unwrap();
    let (is_debug, opt_level) = match build_profile.as_ref() {
        "debug" => (true, "0"),
        _ => (false, "3"),
    };

    // Build the backend
    let deps_folder = build_backend(is_debug, opt_level, &out_dir);

    // Build the FFI layer calling the backend above
    build_ffi_layer(&deps_folder);

    // Emit linkage search path
    probe!("ompi", MPI_REQUIRED_VERSION);

    // Backend
    // BACKEND_DEPS.iter().for_each(|name| {
    //     println!("cargo:rustc-link-lib=static={}", name);
    // });
}
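Illustrative note (not part of the commit): the probe! macro above is equivalent to the plain function below, which asks pkg-config for the unversioned library name first and only falls back to the "<name>-<version>" form when that lookup fails.

// Sketch only: the same lookup logic as probe!, written as a function for clarity.
fn probe(name: &str, version: &str) {
    if pkg_config::probe_library(name).is_err() {
        pkg_config::probe_library(&format!("{}-{}", name, version))
            .unwrap_or_else(|_| panic!("Failed to locate {}", name));
    }
}

// Usage mirrors the call in main(): probe("ompi", MPI_REQUIRED_VERSION);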
cmake/spdlog.cmake

@@ -4,9 +4,10 @@ set(SPDLOG_FMT_EXTERNAL ON)
 
 # Define the level at which SPDLOG_ compilation level is defined
 if (CMAKE_BUILD_TYPE STREQUAL "Debug")
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
+    message(STATUS "Verbose logging is enabled in debug build")
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_DEBUG)
 else()
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_INFO)
 endif ()
 
 fetchcontent_declare(
backends/llamacpp/csrc/backend.cpp

@@ -46,8 +46,11 @@ namespace huggingface::tgi::backends::llama {
     }
 
     TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model, llama_context *const ctx)
-    : model(model), ctx(ctx), batch() {
+    : model(model), ctx(ctx), batch()
+    {
+        char modelName[128];
+        llama_model_meta_val_str(model, "general.name", modelName, sizeof(modelName));
+        SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
     }
 
     TgiLlamaCppBackend::~TgiLlamaCppBackend() {

@@ -63,4 +66,8 @@ namespace huggingface::tgi::backends::llama {
             llama_free(ctx);
         }
     }
+
+    void TgiLlamaCppBackend::schedule() {
+        std::vector<llama_token> tokens;
+    }
 }
backends/llamacpp/csrc/backend.hpp

@@ -1,7 +1,6 @@
 //
 // Created by Morgan Funtowicz on 9/28/2024.
 //
 
 #ifndef TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
 #define TGI_LLAMA_CPP_BACKEND_BACKEND_HPP

@@ -9,7 +8,7 @@
 #include <llama.h>
 
 namespace huggingface::tgi::backends::llama {
-    const char* TGI_BACKEND_LLAMA_CPP_NAME = "llama.cpp";
+    // const char* TGI_BACKEND_LLAMA_CPP_NAME = "llama.cpp";
 
 
     class TgiLlamaCppBackend {

@@ -18,8 +17,10 @@ namespace huggingface::tgi::backends::llama {
         llama_context* ctx;
         llama_batch batch;
     public:
-        TgiLlamaCppBackend(llama_model* const model, llama_context* const);
+        TgiLlamaCppBackend(llama_model *model, llama_context *ctx);
         ~TgiLlamaCppBackend();
+
+        void schedule();
     };
 
     std::unique_ptr<TgiLlamaCppBackend> CreateLlamaCppBackend(std::string_view root);
backends/llamacpp/offline/main.cpp (new file, 22 lines)

//
// Created by mfuntowicz on 10/3/24.
//

#include <string_view>
#include <fmt/format.h>
#include <fmt/color.h>
#include <spdlog/spdlog.h>
#include "../csrc/backend.hpp"

int main(int argc, char** argv) {
    if(argc < 2) {
        fmt::print("No model folder provider");
        return 1;
    }

    spdlog::set_level(spdlog::level::debug);

    const std::string_view model_root = argv[1];
    auto backend = huggingface::tgi::backends::llama::CreateLlamaCppBackend(model_root);
    fmt::print(fmt::emphasis::bold | fg(fmt::color::yellow), "Successfully initialized llama.cpp model from {}\n", model_root);
}
backends/llamacpp/src/backend.rs (new file, 18 lines)

use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use tokio_stream::wrappers::UnboundedReceiverStream;

pub struct TgiLlamaCppBakend {}

impl Backend for TgiLlamaCppBakend {
    fn schedule(
        &self,
        request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        Err(InferError::GenerationError("Not implemented yet".into()))
    }

    async fn health(&self, current_health: bool) -> bool {
        todo!()
    }
}
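The schedule signature above hands the router an UnboundedReceiverStream, while the stub simply returns an error for now. As a point of reference, the sketch below shows the channel plumbing that return type implies, using a plain String payload and hypothetical names rather than real InferStreamResponse values (illustrative only, not part of the commit):

use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio_stream::wrappers::UnboundedReceiverStream;

// schedule() would return the stream half, while a background task pushes
// generated tokens into the sender half as llama.cpp produces them.
fn make_stream() -> (UnboundedSender<String>, UnboundedReceiverStream<String>) {
    let (sender, receiver) = unbounded_channel();
    (sender, UnboundedReceiverStream::new(receiver))
}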
backends/llamacpp/src/lib.rs (new file, 11 lines)

pub mod backend;

#[cxx::bridge(namespace = "huggingface::tgi::backends::llama")]
mod ffi {
    unsafe extern "C++" {
        include!("backends/llamacpp/csrc/backend.cpp");

        /// Represent an instance of the llama.cpp backend instance on C++ side
        type LlamaCppBackendImpl;
    }
}
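For context, a cxx bridge like this one is usually extended with free functions so the Rust side can construct and drive the opaque C++ type. The sketch below is hypothetical: the factory name, its signature, and the header include are assumptions, not part of this commit.

#[cxx::bridge(namespace = "huggingface::tgi::backends::llama")]
mod ffi_sketch {
    unsafe extern "C++" {
        include!("backends/llamacpp/csrc/backend.hpp");

        type LlamaCppBackendImpl;

        // Hypothetical factory: builds the C++ backend from a model directory
        // and returns ownership to Rust as a UniquePtr.
        fn CreateLlamaCppBackendImpl(model_root: &str) -> UniquePtr<LlamaCppBackendImpl>;
    }
}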
backends/llamacpp/src/main.rs (rewritten from the "Hello, world!" placeholder)

@@ -1,3 +1,202 @@
use clap::{Parser, Subcommand};
use text_generation_router::{server, usage_stats};
use thiserror::Error;
use text_generation_router::server::ApiDoc;

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[command(subcommand)]
    command: Option<Commands>,

    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    api_key: Option<String>,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
    #[clap(default_value = "on", long, env)]
    usage_stats: usage_stats::UsageStatsLevel,
}

#[derive(Debug, Subcommand)]
enum Commands {
    PrintSchema,
}

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        command,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        validation_workers,
        api_key,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        usage_stats,
    } = args;

    if let Some(Commands::PrintSchema) = command {
        use utoipa::OpenApi;
        let api_doc = ApiDoc::openapi();
        let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
        println!("{}", api_doc);
        std::process::exit(0);
    };
    text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }

    if let Some(max_batch_size) = max_batch_size {
        if max_batch_size == 0 {
            return Err(RouterError::ArgumentValidation(
                "`max_batch_size` must be > 0".to_string(),
            ));
        }
    }

    let backend = LlamaCppBackend::new();

    // Run server
    server::run(
        backend,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        validation_workers,
        api_key,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        hostname,
        port,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        usage_stats,
    )
    .await?;
    Ok(())
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("Backend failed: {0}")]
    Backend(#[from] V3Error),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}
CMakeLists.txt (TensorRT-LLM backend)

@@ -18,6 +18,8 @@ set(CMAKE_CXX_STANDARD 20)
 include(FetchContent)
 include(ExternalProject)
 
+set(CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS} "--allow-unsupported-compiler -ccbin=gcc")
+
 option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
 option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
 set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")