diff --git a/Cargo.lock b/Cargo.lock
index 27404e41..74246d69 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3460,10 +3460,12 @@ version = "2.0.5-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
+ "clap",
  "cmake",
  "cxx",
  "cxx-build",
  "text-generation-router",
+ "thiserror",
  "tokio",
  "tokio-stream",
 ]
diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt
index ef8700d9..189984bc 100644
--- a/backends/trtllm/CMakeLists.txt
+++ b/backends/trtllm/CMakeLists.txt
@@ -29,13 +29,12 @@ include_directories("${trtllm_SOURCE_DIR}/cpp/include")
 message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")

 # TGI TRTLLM Backend definition
-add_library(tgi_trtllm_backend_impl include/backend.h lib/backend.cpp)
-
+add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp)
 target_include_directories(tgi_trtllm_backend_impl PRIVATE
         $
         $
 )
-target_link_libraries(tgi_trtllm_backend_impl PRIVATE spdlog)
+target_link_libraries(tgi_trtllm_backend_impl PUBLIC spdlog::spdlog)

 #### Unit Tests ####
 if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml
index 39369c48..bfe59346 100644
--- a/backends/trtllm/Cargo.toml
+++ b/backends/trtllm/Cargo.toml
@@ -12,6 +12,8 @@ cxx = "1.0"
 text-generation-router = { path = "../../router" }
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.14"
+clap = { version = "4.5.4", features = ["derive"] }
+thiserror = "1.0.61"

 [build-dependencies]
 cmake = "0.1"
diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs
index 32d0f943..4a671712 100644
--- a/backends/trtllm/build.rs
+++ b/backends/trtllm/build.rs
@@ -1,19 +1,63 @@
+use std::env;
+use std::path::PathBuf;
+
 use cxx_build::CFG;

+const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
+
 fn main() {
+    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
+    let build_profile = env::var("PROFILE").unwrap();
+
+    // Build the backend implementation through CMake
     let backend_path = cmake::Config::new(".")
         .uses_cxx11()
         .generator("Ninja")
+        .profile(match build_profile.as_ref() {
+            "release" => "Release",
+            _ => "Debug",
+        })
         .build_target("tgi_trtllm_backend_impl")
         .build();

+    // Build the FFI layer calling the backend above
     CFG.include_prefix = "backends/trtllm";
     cxx_build::bridge("src/lib.rs")
+        .static_flag(true)
         .file("src/ffi.cpp")
         .std("c++20")
         .compile("tgi_trtllm_backend");

+    println!("cargo:rerun-if-changed=CMakeLists.txt");
     println!("cargo:rerun-if-changed=include/backend.h");
     println!("cargo:rerun-if-changed=lib/backend.cpp");
-    // println!("cargo:rustc-link-lib=tgi_trtllm_backend_impl");
+    println!("cargo:rerun-if-changed=src/ffi.cpp");
+
+    // Additional transitive CMake dependencies
+    for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
+        let dep_folder = out_dir
+            .join("build")
+            .join("_deps")
+            .join(format!("{}-build", dependency));
+
+        let dep_name = match build_profile.as_ref() {
+            "debug" => format!("{}d", dependency),
+            _ => String::from(dependency),
+        };
+        println!("cargo:warning={}", dep_folder.display());
+        println!("cargo:rustc-link-search=native={}", dep_folder.display());
+        println!("cargo:rustc-link-lib=static={}", dep_name);
+    }
+
+    // Emit linkage information
+    // - tgi_trtllm_backend (i.e. FFI layer - src/ffi.cpp)
+    println!(r"cargo:rustc-link-search=native={}", backend_path.display());
+    println!("cargo:rustc-link-lib=static=tgi_trtllm_backend");
+
+    // - tgi_trtllm_backend_impl (i.e. C++ code base to run inference include/backend.h)
+    println!(
+        r"cargo:rustc-link-search=native={}/build",
+        backend_path.display()
+    );
+    println!("cargo:rustc-link-lib=static=tgi_trtllm_backend_impl");
 }
diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h
index d2b7f853..591e676a 100644
--- a/backends/trtllm/include/backend.h
+++ b/backends/trtllm/include/backend.h
@@ -13,21 +13,13 @@
 //namespace tle = tensorrt_llm::executor;

 namespace huggingface::tgi::backends {
-    class TensorRtLlmBackendImpl {
+    class TensorRtLlmBackend {
     private:
 //        tle::Executor executor;

     public:
-        TensorRtLlmBackendImpl(std::filesystem::path &engineFolder);
+        TensorRtLlmBackend(const std::filesystem::path &engineFolder);
     };
-
-    /***
-     *
-     * @param engineFolder
-     * @return
-     */
-    std::unique_ptr<TensorRtLlmBackendImpl>
-    create_trtllm_backend(std::filesystem::path &engineFolder);
 }

 #endif //TGI_TRTLLM_BACKEND_H
diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
index ec447fbf..91df8451 100644
--- a/backends/trtllm/lib/backend.cpp
+++ b/backends/trtllm/lib/backend.cpp
@@ -3,11 +3,6 @@

 #include "backend.h"

-huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(std::filesystem::path &engineFolder) {
+huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(const std::filesystem::path &engineFolder) {
     SPDLOG_INFO(FMT_STRING("Loading engines from {}"), engineFolder);
 }
-
-std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
-huggingface::tgi::backends::create_trtllm_backend(std::filesystem::path &engineFolder) {
-    return std::make_unique<TensorRtLlmBackendImpl>(engineFolder);
-}
diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs
index 3b3b5bce..6f5c960c 100644
--- a/backends/trtllm/src/backend.rs
+++ b/backends/trtllm/src/backend.rs
@@ -1,19 +1,41 @@
+use std::path::Path;
+
+use async_trait::async_trait;
+use cxx::UniquePtr;
 use tokio_stream::wrappers::UnboundedReceiverStream;

 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;

-pub struct TensorRtLLmBackend {}
+use crate::errors::TensorRtLlmBackendError;
+use crate::ffi::{create_trtllm_backend, TensorRtLlmBackend};

-impl Backend for TensorRtLLmBackend {
+pub struct TrtLLmBackend {
+    inner: UniquePtr<TensorRtLlmBackend>,
+}
+
+unsafe impl Sync for TrtLLmBackend {}
+unsafe impl Send for TrtLLmBackend {}
+
+impl TrtLLmBackend {
+    pub fn new<P: AsRef<Path>>(engine_folder: P) -> Result<Self, TensorRtLlmBackendError> {
+        let engine_folder = engine_folder.as_ref();
+        let inner = create_trtllm_backend(engine_folder.to_str().unwrap());
+
+        Ok(Self { inner })
+    }
+}
+
+#[async_trait]
+impl Backend for TrtLLmBackend {
     fn schedule(
         &self,
-        request: ValidGenerateRequest,
+        _request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
         todo!()
     }

-    async fn health(&self, current_health: bool) -> bool {
-        todo!()
+    async fn health(&self, _current_health: bool) -> bool {
+        true
     }
 }
diff --git a/backends/trtllm/src/errors.rs b/backends/trtllm/src/errors.rs
new file mode 100644
index 00000000..3d6c2033
--- /dev/null
+++ b/backends/trtllm/src/errors.rs
@@ -0,0 +1,13 @@
+use thiserror::Error;
+
+use text_generation_router::server;
+
+#[derive(Debug, Error)]
+pub enum TensorRtLlmBackendError {
+    #[error("Argument validation error: {0}")]
+    ArgumentValidation(String),
+    #[error("WebServer error: {0}")]
+    WebServer(#[from] server::WebServerError),
+    #[error("Tokio runtime failed to start: {0}")]
+    Tokio(#[from] std::io::Error),
+}
diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp
index 70f23394..215f602e 100644
--- a/backends/trtllm/src/ffi.cpp
+++ b/backends/trtllm/src/ffi.cpp
@@ -2,13 +2,19 @@
 // Created by mfuntowicz on 6/30/24.
 //
 #include
+#include "rust/cxx.h"
+#include "backends/trtllm/include/backend.h"

-namespace huggingface::tgi::backends::trtllm {
-    class TensorRtLlmBackend {
-    public:
-        TensorRtLlmBackend(std::filesystem::path engineFolder) {
+namespace huggingface::tgi::backends {
+    /***
+     *
+     * @param engineFolder
+     * @return
+     */
+    std::unique_ptr<TensorRtLlmBackend> create_trtllm_backend(rust::Str engineFolder) {
+        const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
+        return std::make_unique<TensorRtLlmBackend>(enginePath);
+    }
-
-        }
-    };
 }
\ No newline at end of file
diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index 0ac550cb..b2c6e45b 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -1,11 +1,15 @@
+pub use backend::TrtLLmBackend;
+
 mod backend;
+pub mod errors;

 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
     unsafe extern "C++" {
-        include!("backends/trtllm/include/backend.h");
+        include!("backends/trtllm/src/ffi.cpp");

-        type TensorRtLlmBackendImpl;
+        type TensorRtLlmBackend;
+        fn create_trtllm_backend(engine_folder: &str) -> UniquePtr<TensorRtLlmBackend>;

     }
 }
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs
new file mode 100644
index 00000000..ee143072
--- /dev/null
+++ b/backends/trtllm/src/main.rs
@@ -0,0 +1,137 @@
+use clap::Parser;
+
+use text_generation_backends_trtllm::{errors::TensorRtLlmBackendError, TrtLLmBackend};
+use text_generation_router::server;
+
+/// App Configuration
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+    #[clap(default_value = "128", long, env)]
+    max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
+    #[clap(default_value = "4", long, env)]
+    max_stop_sequences: usize,
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+    #[clap(default_value = "1024", long, env)]
+    max_input_tokens: usize,
+    #[clap(default_value = "2048", long, env)]
+    max_total_tokens: usize,
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+    #[clap(default_value = "3000", long, short, env)]
+    port: u16,
+    #[clap(default_value = "bigscience/bloom", long, env)]
+    tokenizer_name: String,
+    #[clap(long, env)]
+    tokenizer_config_path: Option<String>,
+    #[clap(long, env)]
+    revision: Option<String>,
+    #[clap(long, env)]
+    model_id: String,
+    #[clap(default_value = "2", long, env)]
+    validation_workers: usize,
+    #[clap(long, env)]
+    json_output: bool,
+    #[clap(long, env)]
+    otlp_endpoint: Option<String>,
+    #[clap(default_value = "text-generation-inference.router", long, env)]
+    otlp_service_name: String,
+    #[clap(long, env)]
+    cors_allow_origin: Option<Vec<String>>,
+    #[clap(long, env, default_value_t = false)]
+    messages_api_enabled: bool,
+    #[clap(default_value = "4", long, env)]
+    max_client_batch_size: usize,
+}
+
+#[tokio::main]
+async fn main() -> Result<(), TensorRtLlmBackendError> {
+    // Get args
+    let args = Args::parse();
+    // Pattern match configuration
+    let Args {
+        max_concurrent_requests,
+        max_best_of,
+        max_stop_sequences,
+        max_top_n_tokens,
+        max_input_tokens,
+        max_total_tokens,
+        max_batch_prefill_tokens,
+        max_batch_total_tokens,
+        hostname,
+        port,
+        tokenizer_name,
+        tokenizer_config_path,
+        revision,
+        model_id,
+        validation_workers,
+        json_output,
+        otlp_endpoint,
+        otlp_service_name,
+        cors_allow_origin,
+        messages_api_enabled,
+        max_client_batch_size,
+    } = args;
+
+    // Launch Tokio runtime
+    text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
+
+    // Validate args
+    if max_input_tokens >= max_total_tokens {
+        return Err(TensorRtLlmBackendError::ArgumentValidation(
+            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
+        ));
+    }
+    if max_input_tokens as u32 > max_batch_prefill_tokens {
+        return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
+    }
+
+    if validation_workers == 0 {
+        return Err(TensorRtLlmBackendError::ArgumentValidation(
+            "`validation_workers` must be > 0".to_string(),
+        ));
+    }
+
+    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
+        if max_batch_prefill_tokens > *max_batch_total_tokens {
+            return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
+        }
+        if max_total_tokens as u32 > *max_batch_total_tokens {
+            return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
+        }
+    }
+
+    // Run server
+    let backend = TrtLLmBackend::new(model_id)?;
+    server::run(
+        backend,
+        max_concurrent_requests,
+        max_best_of,
+        max_stop_sequences,
+        max_top_n_tokens,
+        max_input_tokens,
+        max_total_tokens,
+        validation_workers,
+        tokenizer_name,
+        tokenizer_config_path,
+        revision,
+        hostname,
+        port,
+        cors_allow_origin,
+        false,
+        None,
+        None,
+        messages_api_enabled,
+        true,
+        max_client_batch_size,
+    )
+    .await?;
+    Ok(())
+}
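
Usage note: the pieces above wire together as follows. build.rs compiles and links the CMake target, src/lib.rs exposes the C++ create_trtllm_backend factory through a cxx bridge, src/backend.rs wraps the returned UniquePtr in TrtLLmBackend, and src/main.rs hands that backend to server::run. The standalone Rust sketch below (not part of the diff) shows how the new API could be driven directly, assuming the crate and its CMake/TensorRT-LLM dependencies build and link successfully; the engine folder path is purely illustrative.

// Standalone sketch, not part of the changeset above.
// Assumes the crate links against a successful CMake/TensorRT-LLM build;
// the engine folder path is hypothetical.
use text_generation_backends_trtllm::{errors::TensorRtLlmBackendError, TrtLLmBackend};
use text_generation_router::infer::Backend;

#[tokio::main]
async fn main() -> Result<(), TensorRtLlmBackendError> {
    // Construct the backend from a folder of pre-built TensorRT-LLM engines.
    let backend = TrtLLmBackend::new("/data/engines/llama-3-8b")?;

    // In this changeset `health` simply returns `true`; `schedule` is still `todo!()`.
    let healthy = backend.health(false).await;
    println!("trtllm backend healthy: {healthy}");
    Ok(())
}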