Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
Working FFI call for TGI and TRTLLM backend
Parent: dc402dc9ac
Commit: 47ac5c654d
Cargo.lock (generated): 2 changed lines
@@ -3460,10 +3460,12 @@ version = "2.0.5-dev0"
dependencies = [
 "async-stream",
 "async-trait",
 "clap",
 "cmake",
 "cxx",
 "cxx-build",
 "text-generation-router",
 "thiserror",
 "tokio",
 "tokio-stream",
]
@@ -29,13 +29,12 @@ include_directories("${trtllm_SOURCE_DIR}/cpp/include")
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")

# TGI TRTLLM Backend definition
-add_library(tgi_trtllm_backend_impl include/backend.h lib/backend.cpp)
-
+add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp)
target_include_directories(tgi_trtllm_backend_impl PRIVATE
        $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
        $<INSTALL_INTERFACE:include>
)
-target_link_libraries(tgi_trtllm_backend_impl PRIVATE spdlog)
+target_link_libraries(tgi_trtllm_backend_impl PUBLIC spdlog::spdlog)

#### Unit Tests ####
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
@@ -12,6 +12,8 @@ cxx = "1.0"
text-generation-router = { path = "../../router" }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
clap = { version = "4.5.4", features = ["derive"] }
thiserror = "1.0.61"

[build-dependencies]
cmake = "0.1"
@@ -1,19 +1,63 @@
use std::env;
use std::path::PathBuf;

use cxx_build::CFG;

const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];

fn main() {
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    let build_profile = env::var("PROFILE").unwrap();

    // Build the backend implementation through CMake
    let backend_path = cmake::Config::new(".")
        .uses_cxx11()
        .generator("Ninja")
        .profile(match build_profile.as_ref() {
            "release" => "Release",
            _ => "Debug",
        })
        .build_target("tgi_trtllm_backend_impl")
        .build();

    // Build the FFI layer calling the backend above
    CFG.include_prefix = "backends/trtllm";
    cxx_build::bridge("src/lib.rs")
        .static_flag(true)
        .file("src/ffi.cpp")
        .std("c++20")
        .compile("tgi_trtllm_backend");

    println!("cargo:rerun-if-changed=CMakeLists.txt");
    println!("cargo:rerun-if-changed=include/backend.h");
    println!("cargo:rerun-if-changed=lib/backend.cpp");
    // println!("cargo:rustc-link-lib=tgi_trtllm_backend_impl");
    println!("cargo:rerun-if-changed=src/ffi.cpp");

    // Additional transitive CMake dependencies
    for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
        let dep_folder = out_dir
            .join("build")
            .join("_deps")
            .join(format!("{}-build", dependency));

        let dep_name = match build_profile.as_ref() {
            "debug" => format!("{}d", dependency),
            _ => String::from(dependency),
        };
        println!("cargo:warning={}", dep_folder.display());
        println!("cargo:rustc-link-search=native={}", dep_folder.display());
        println!("cargo:rustc-link-lib=static={}", dep_name);
    }

    // Emit linkage information
    // - tgi_trtllm_backend (i.e. FFI layer - src/ffi.cpp)
    println!(r"cargo:rustc-link-search=native={}", backend_path.display());
    println!("cargo:rustc-link-lib=static=tgi_trtllm_backend");

    // - tgi_trtllm_backend_impl (i.e. C++ code base to run inference include/backend.h)
    println!(
        r"cargo:rustc-link-search=native={}/build",
        backend_path.display()
    );
    println!("cargo:rustc-link-lib=static=tgi_trtllm_backend_impl");
}
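The loop over ADDITIONAL_BACKEND_LINK_LIBRARIES in the build script above assumes CMake's FetchContent has already produced `OUT_DIR/build/_deps/spdlog-build` and `fmt-build`, and that debug archives carry a `d` suffix (libspdlogd.a). Below is a minimal sketch of the same idea with an existence check, so a missing dependency fails inside the build script instead of at link time; the helper name and the panic are illustrative only and are not part of the commit.

use std::env;
use std::path::PathBuf;

/// Sketch only: emit link flags for one CMake FetchContent dependency,
/// failing early if the expected build folder is missing.
fn link_cmake_dep(out_dir: &PathBuf, dependency: &str, build_profile: &str) {
    // Same layout as the build script above: OUT_DIR/build/_deps/<dep>-build
    let dep_folder = out_dir
        .join("build")
        .join("_deps")
        .join(format!("{}-build", dependency));

    if !dep_folder.is_dir() {
        panic!(
            "expected CMake dependency folder {} to exist",
            dep_folder.display()
        );
    }

    // Debug CMake builds of spdlog/fmt append a 'd' to the archive name
    let dep_name = match build_profile {
        "debug" => format!("{}d", dependency),
        _ => dependency.to_string(),
    };

    println!("cargo:rustc-link-search=native={}", dep_folder.display());
    println!("cargo:rustc-link-lib=static={}", dep_name);
}

fn main() {
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    let build_profile = env::var("PROFILE").unwrap();
    for dep in ["spdlog", "fmt"] {
        link_cmake_dep(&out_dir, dep, &build_profile);
    }
}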
@@ -13,21 +13,13 @@
//namespace tle = tensorrt_llm::executor;

namespace huggingface::tgi::backends {
-    class TensorRtLlmBackendImpl {
+    class TensorRtLlmBackend {
    private:
        // tle::Executor executor;

    public:
-        TensorRtLlmBackendImpl(std::filesystem::path &engineFolder);
+        TensorRtLlmBackend(const std::filesystem::path &engineFolder);
    };

-    /***
-     *
-     * @param engineFolder
-     * @return
-     */
-    std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
-    create_trtllm_backend(std::filesystem::path &engineFolder);
}

#endif //TGI_TRTLLM_BACKEND_H
@@ -3,11 +3,6 @@

#include "backend.h"

-huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(std::filesystem::path &engineFolder) {
+huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(const std::filesystem::path &engineFolder) {
    SPDLOG_INFO(FMT_STRING("Loading engines from {}"), engineFolder);
}
-
-std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
-huggingface::tgi::backends::create_trtllm_backend(std::filesystem::path &engineFolder) {
-    return std::make_unique<huggingface::tgi::backends::TensorRtLlmBackendImpl>(engineFolder);
-}
@@ -1,19 +1,41 @@
+use std::path::Path;

use async_trait::async_trait;
+use cxx::UniquePtr;
use tokio_stream::wrappers::UnboundedReceiverStream;

use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;

-pub struct TensorRtLLmBackend {}
+use crate::errors::TensorRtLlmBackendError;
+use crate::ffi::{create_trtllm_backend, TensorRtLlmBackend};

-impl Backend for TensorRtLLmBackend {
+pub struct TrtLLmBackend {
+    inner: UniquePtr<TensorRtLlmBackend>,
+}

+unsafe impl Sync for TrtLLmBackend {}
+unsafe impl Send for TrtLLmBackend {}

+impl TrtLLmBackend {
+    pub fn new<P: AsRef<Path>>(engine_folder: P) -> Result<Self, TensorRtLlmBackendError> {
+        let engine_folder = engine_folder.as_ref();
+        let inner = create_trtllm_backend(engine_folder.to_str().unwrap());
+
+        Ok(Self { inner })
+    }
+}

#[async_trait]
+impl Backend for TrtLLmBackend {
    fn schedule(
        &self,
-        request: ValidGenerateRequest,
+        _request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        todo!()
    }

-    async fn health(&self, current_health: bool) -> bool {
-        todo!()
+    async fn health(&self, _current_health: bool) -> bool {
+        true
    }
}
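At this point in the series, schedule() is still a todo!(). A hedged sketch of how it could eventually hand back a stream, assuming the executor loop pushes results through a Tokio unbounded channel (this matches the UnboundedReceiverStream type already used in the signature; the background task itself is omitted and nothing below is from the commit):

    fn schedule(
        &self,
        _request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        // Create the channel whose receiving half is handed back to the router.
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();

        // A background task driving the TensorRT-LLM executor would keep `sender`
        // and push Ok(InferStreamResponse::...) items as tokens are produced.
        // Dropping it here simply yields an empty, immediately closed stream.
        drop(sender);

        Ok(UnboundedReceiverStream::new(receiver))
    }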
backends/trtllm/src/errors.rs (new file, 13 lines)
@@ -0,0 +1,13 @@
use thiserror::Error;

use text_generation_router::server;

#[derive(Debug, Error)]
pub enum TensorRtLlmBackendError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}
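The #[from] conversions above are what let the `?` operator in main.rs further down turn `server::WebServerError` and `std::io::Error` into `TensorRtLlmBackendError` automatically. A self-contained illustration of the same thiserror pattern, with the type and function names invented for the example:

use thiserror::Error;

#[derive(Debug, Error)]
enum DemoError {
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
}

fn read_config() -> Result<String, DemoError> {
    // `?` converts std::io::Error into DemoError::Io via the #[from] impl
    Ok(std::fs::read_to_string("config.toml")?)
}

fn main() {
    match read_config() {
        Ok(cfg) => println!("loaded {} bytes", cfg.len()),
        Err(e) => eprintln!("{}", e),
    }
}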
@@ -2,13 +2,19 @@
// Created by mfuntowicz on 6/30/24.
//
#include <filesystem>
+#include "rust/cxx.h"
+
+#include "backends/trtllm/include/backend.h"

-namespace huggingface::tgi::backends::trtllm {
-    class TensorRtLlmBackend {
-    public:
-        TensorRtLlmBackend(std::filesystem::path engineFolder) {
+namespace huggingface::tgi::backends {
+    /***
+     *
+     * @param engineFolder
+     * @return
+     */
+    std::unique_ptr<TensorRtLlmBackend> create_trtllm_backend(rust::Str engineFolder) {
+        const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
+        return std::make_unique<TensorRtLlmBackend>(enginePath);
    }

}
-    };
}
@@ -1,11 +1,15 @@
pub use backend::TrtLLmBackend;

mod backend;
pub mod errors;

#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
    unsafe extern "C++" {
        include!("backends/trtllm/include/backend.h");
        include!("backends/trtllm/src/ffi.cpp");

-        type TensorRtLlmBackendImpl;
+        type TensorRtLlmBackend;

        fn create_trtllm_backend(engine_folder: &str) -> UniquePtr<TensorRtLlmBackend>;
    }
}
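With this bridge, cxx generates a Rust `ffi` module in which `TensorRtLlmBackend` is an opaque C++ type and `create_trtllm_backend` returns a `cxx::UniquePtr` owning it, which is exactly what `TrtLLmBackend` wraps. A small usage sketch from the binary crate; the engine path is a placeholder:

use text_generation_backends_trtllm::TrtLLmBackend;

fn main() {
    // TrtLLmBackend::new wraps ffi::create_trtllm_backend and stores the
    // returned UniquePtr<TensorRtLlmBackend> in its `inner` field.
    match TrtLLmBackend::new("/path/to/engines") {
        Ok(_backend) => println!("backend created"),
        Err(e) => eprintln!("failed to create backend: {e}"),
    }
}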
backends/trtllm/src/main.rs (new file, 137 lines)
@@ -0,0 +1,137 @@
use clap::Parser;

use text_generation_backends_trtllm::{errors::TensorRtLlmBackendError, TrtLLmBackend};
use text_generation_router::server;

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(long, env)]
    model_id: String,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}

#[tokio::main]
async fn main() -> Result<(), TensorRtLlmBackendError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        hostname,
        port,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        model_id,
        validation_workers,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        messages_api_enabled,
        max_client_batch_size,
    } = args;

    // Launch Tokio runtime
    text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(TensorRtLlmBackendError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(TensorRtLlmBackendError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }

    // Run server
    let backend = TrtLLmBackend::new(model_id)?;
    server::run(
        backend,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        validation_workers,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        hostname,
        port,
        cors_allow_origin,
        false,
        None,
        None,
        messages_api_enabled,
        true,
        max_client_batch_size,
    )
    .await?;
    Ok(())
}
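Because every option on `Args` is exposed both as a flag and as an environment variable, the CLI surface can be exercised without any TensorRT-LLM engine present. A hedged sketch of a unit test using clap's `try_parse_from`; the test module and values are illustrative, not part of the commit, and it assumes it lives next to `Args` inside main.rs:

#[cfg(test)]
mod tests {
    use super::Args;
    use clap::Parser;

    #[test]
    fn parses_token_limits() {
        // First element is the binary name; model_id has no default, so pass it.
        let args = Args::try_parse_from([
            "text-generation-backends-trtllm",
            "--model-id", "dummy/model",
            "--max-input-tokens", "512",
            "--max-total-tokens", "1024",
        ])
        .expect("arguments should parse");

        assert_eq!(args.max_input_tokens, 512);
        assert_eq!(args.max_total_tokens, 1024);
    }
}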