Working FFI call for TGI and TRTLLM backend

Morgan Funtowicz 2024-07-01 15:53:23 +02:00
parent dc402dc9ac
commit 47ac5c654d
11 changed files with 249 additions and 33 deletions

Cargo.lock generated

@@ -3460,10 +3460,12 @@ version = "2.0.5-dev0"
dependencies = [
"async-stream",
"async-trait",
"clap",
"cmake",
"cxx",
"cxx-build",
"text-generation-router",
"thiserror",
"tokio",
"tokio-stream",
]

backends/trtllm/CMakeLists.txt

@@ -29,13 +29,12 @@ include_directories("${trtllm_SOURCE_DIR}/cpp/include")
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
# TGI TRTLLM Backend definition
-add_library(tgi_trtllm_backend_impl include/backend.h lib/backend.cpp)
+add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp)
target_include_directories(tgi_trtllm_backend_impl PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
)
-target_link_libraries(tgi_trtllm_backend_impl PRIVATE spdlog)
+target_link_libraries(tgi_trtllm_backend_impl PUBLIC spdlog::spdlog)
#### Unit Tests ####
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})

backends/trtllm/Cargo.toml

@@ -12,6 +12,8 @@ cxx = "1.0"
text-generation-router = { path = "../../router" }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
+clap = { version = "4.5.4", features = ["derive"] }
+thiserror = "1.0.61"
[build-dependencies]
cmake = "0.1"

backends/trtllm/build.rs

@@ -1,19 +1,63 @@
use std::env;
use std::path::PathBuf;
use cxx_build::CFG;
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
fn main() {
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let build_profile = env::var("PROFILE").unwrap();
// Build the backend implementation through CMake
let backend_path = cmake::Config::new(".")
.uses_cxx11()
.generator("Ninja")
.profile(match build_profile.as_ref() {
"release" => "Release",
_ => "Debug",
})
.build_target("tgi_trtllm_backend_impl")
.build();
// Build the FFI layer calling the backend above
CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs")
.static_flag(true)
.file("src/ffi.cpp")
.std("c++20")
.compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt");
println!("cargo:rerun-if-changed=include/backend.h");
println!("cargo:rerun-if-changed=lib/backend.cpp");
// println!("cargo:rustc-link-lib=tgi_trtllm_backend_impl");
println!("cargo:rerun-if-changed=src/ffi.cpp");
// Additional transitive CMake dependencies
for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
let dep_folder = out_dir
.join("build")
.join("_deps")
.join(format!("{}-build", dependency));
let dep_name = match build_profile.as_ref() {
"debug" => format!("{}d", dependency),
_ => String::from(dependency),
};
println!("cargo:warning={}", dep_folder.display());
println!("cargo:rustc-link-search=native={}", dep_folder.display());
println!("cargo:rustc-link-lib=static={}", dep_name);
}
// Emit linkage information
// - tgi_trtllm_backend (i.e. FFI layer - src/ffi.cpp)
println!(r"cargo:rustc-link-search=native={}", backend_path.display());
println!("cargo:rustc-link-lib=static=tgi_trtllm_backend");
// - tgi_trtllm_backend_impl (i.e. C++ code base to run inference include/backend.h)
println!(
r"cargo:rustc-link-search=native={}/build",
backend_path.display()
);
println!("cargo:rustc-link-lib=static=tgi_trtllm_backend_impl");
}
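Editor's note on the `{}d` suffix in the dependency loop above: the build script assumes that Debug CMake builds of spdlog and fmt produce static archives with a trailing `d` (e.g. libspdlogd.a), so the emitted `rustc-link-lib` directives must use the suffixed name in debug profiles. A minimal, standalone sketch of that mapping; the `debug_suffixed` helper is hypothetical and not part of the commit:

// Sketch only: mirrors the profile-dependent naming used in build.rs above.
fn debug_suffixed(dependency: &str, profile: &str) -> String {
    match profile {
        // Debug builds of spdlog/fmt are expected to ship e.g. libspdlogd.a
        "debug" => format!("{}d", dependency),
        // Release builds use the plain archive name, e.g. libspdlog.a
        _ => String::from(dependency),
    }
}

fn main() {
    assert_eq!(debug_suffixed("spdlog", "debug"), "spdlogd");
    assert_eq!(debug_suffixed("fmt", "release"), "fmt");
    println!("suffix mapping verified");
}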

backends/trtllm/include/backend.h

@@ -13,21 +13,13 @@
//namespace tle = tensorrt_llm::executor;
namespace huggingface::tgi::backends {
-class TensorRtLlmBackendImpl {
+class TensorRtLlmBackend {
private:
// tle::Executor executor;
public:
-TensorRtLlmBackendImpl(std::filesystem::path &engineFolder);
+TensorRtLlmBackend(const std::filesystem::path &engineFolder);
};
-/***
-*
-* @param engineFolder
-* @return
-*/
-std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
-create_trtllm_backend(std::filesystem::path &engineFolder);
}
#endif //TGI_TRTLLM_BACKEND_H

backends/trtllm/lib/backend.cpp

@@ -3,11 +3,6 @@
#include "backend.h"
-huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(std::filesystem::path &engineFolder) {
+huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(const std::filesystem::path &engineFolder) {
SPDLOG_INFO(FMT_STRING("Loading engines from {}"), engineFolder);
}
-std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
-huggingface::tgi::backends::create_trtllm_backend(std::filesystem::path &engineFolder) {
-return std::make_unique<huggingface::tgi::backends::TensorRtLlmBackendImpl>(engineFolder);
-}

backends/trtllm/src/backend.rs

@@ -1,19 +1,41 @@
+use std::path::Path;
use async_trait::async_trait;
+use cxx::UniquePtr;
use tokio_stream::wrappers::UnboundedReceiverStream;
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
-pub struct TensorRtLLmBackend {}
+use crate::errors::TensorRtLlmBackendError;
+use crate::ffi::{create_trtllm_backend, TensorRtLlmBackend};
-impl Backend for TensorRtLLmBackend {
+pub struct TrtLLmBackend {
+inner: UniquePtr<TensorRtLlmBackend>,
+}
+unsafe impl Sync for TrtLLmBackend {}
+unsafe impl Send for TrtLLmBackend {}
+impl TrtLLmBackend {
+pub fn new<P: AsRef<Path>>(engine_folder: P) -> Result<Self, TensorRtLlmBackendError> {
+let engine_folder = engine_folder.as_ref();
+let inner = create_trtllm_backend(engine_folder.to_str().unwrap());
+Ok(Self { inner })
+}
+}
+#[async_trait]
+impl Backend for TrtLLmBackend {
fn schedule(
&self,
-request: ValidGenerateRequest,
+_request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
todo!()
}
-async fn health(&self, current_health: bool) -> bool {
-todo!()
+async fn health(&self, _current_health: bool) -> bool {
+true
}
}

backends/trtllm/src/errors.rs Normal file

@@ -0,0 +1,13 @@
use thiserror::Error;
use text_generation_router::server;
#[derive(Debug, Error)]
pub enum TensorRtLlmBackendError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}
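Editor's note: the `#[from]` attributes above are what let main.rs bubble errors up with `?`; a `std::io::Error` or `server::WebServerError` is converted into `TensorRtLlmBackendError` automatically. A self-contained sketch of the same thiserror pattern, using a hypothetical `DemoError` type that is not part of the commit:

use thiserror::Error;

#[derive(Debug, Error)]
enum DemoError {
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
}

fn read_config() -> Result<String, DemoError> {
    // `?` converts std::io::Error into DemoError::Io via the #[from] impl.
    Ok(std::fs::read_to_string("config.toml")?)
}

fn main() {
    match read_config() {
        Ok(contents) => println!("read {} bytes", contents.len()),
        Err(err) => eprintln!("{err}"),
    }
}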

backends/trtllm/src/ffi.cpp

@@ -2,13 +2,19 @@
// Created by mfuntowicz on 6/30/24.
//
#include <filesystem>
#include "rust/cxx.h"
#include "backends/trtllm/include/backend.h"
namespace huggingface::tgi::backends::trtllm {
class TensorRtLlmBackend {
public:
TensorRtLlmBackend(std::filesystem::path engineFolder) {
namespace huggingface::tgi::backends {
/***
*
* @param engineFolder
* @return
*/
std::unique_ptr<TensorRtLlmBackend> create_trtllm_backend(rust::Str engineFolder) {
const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
return std::make_unique<TensorRtLlmBackend>(enginePath);
}
};
}

backends/trtllm/src/lib.rs

@@ -1,11 +1,15 @@
+pub use backend::TrtLLmBackend;
mod backend;
+pub mod errors;
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
unsafe extern "C++" {
include!("backends/trtllm/include/backend.h");
include!("backends/trtllm/src/ffi.cpp");
-type TensorRtLlmBackendImpl;
+type TensorRtLlmBackend;
+fn create_trtllm_backend(engine_folder: &str) -> UniquePtr<TensorRtLlmBackend>;
}
}
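Editor's note: taken together, the pieces above are what make the FFI call in the commit title work. The `#[cxx::bridge]` module declares `create_trtllm_backend`, whose C++ definition lives in src/ffi.cpp and returns a `UniquePtr` to the opaque `TensorRtLlmBackend` type, and `TrtLLmBackend::new` in src/backend.rs wraps that call. A hedged usage sketch; the engine path is illustrative and running it requires the compiled C++ backend:

use text_generation_backends_trtllm::TrtLLmBackend;

fn main() {
    // Illustrative path; point this at a folder containing TensorRT-LLM engines.
    let backend = TrtLLmBackend::new("/data/trtllm-engines/llama")
        .expect("failed to initialize the TensorRT-LLM backend");
    // At this point the wrapper owns a cxx::UniquePtr to the C++ backend object.
    drop(backend);
}

main.rs below does the same thing through `TrtLLmBackend::new(model_id)` before handing the backend to the router.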

backends/trtllm/src/main.rs Normal file

@@ -0,0 +1,137 @@
use clap::Parser;
use text_generation_backends_trtllm::{errors::TensorRtLlmBackendError, TrtLLmBackend};
use text_generation_router::server;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env)]
model_id: String,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
}
#[tokio::main]
async fn main() -> Result<(), TensorRtLlmBackendError> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
max_batch_prefill_tokens,
max_batch_total_tokens,
hostname,
port,
tokenizer_name,
tokenizer_config_path,
revision,
model_id,
validation_workers,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
messages_api_enabled,
max_client_batch_size,
} = args;
// Launch Tokio runtime
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(TensorRtLlmBackendError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
// Run server
let backend = TrtLLmBackend::new(model_id)?;
server::run(
backend,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
validation_workers,
tokenizer_name,
tokenizer_config_path,
revision,
hostname,
port,
cors_allow_origin,
false,
None,
None,
messages_api_enabled,
true,
max_client_batch_size,
)
.await?;
Ok(())
}
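Editor's note on the clap derive above (assuming clap 4 with the `derive` and `env` features available): each snake_case field becomes a kebab-case `--flag`, `env` exposes it as an upper-cased environment variable, and the string defaults are parsed into the field type before the validation checks in main run. A trimmed-down, hypothetical sketch:

use clap::Parser;

#[derive(Parser, Debug)]
struct DemoArgs {
    // Becomes `--max-input-tokens` and the MAX_INPUT_TOKENS env var.
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
}

fn main() {
    // Parse from an explicit argv so the example is deterministic.
    let args = DemoArgs::parse_from(["demo", "--max-input-tokens", "512"]);
    assert_eq!(args.max_input_tokens, 512);
    assert_eq!(args.max_total_tokens, 2048); // default applied
    println!("{args:?}");
}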