diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml index f07c26d1..a5f09c11 100644 --- a/backends/trtllm/Cargo.toml +++ b/backends/trtllm/Cargo.toml @@ -22,4 +22,5 @@ log = { version = "0.4", features = [] } [build-dependencies] cmake = "0.1" -cxx-build = "1.0" \ No newline at end of file +cxx-build = "1.0" +pkg-config = "0.3" \ No newline at end of file diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs index 0fa62ec9..4c5e6d80 100644 --- a/backends/trtllm/build.rs +++ b/backends/trtllm/build.rs @@ -2,11 +2,24 @@ use std::env; use std::path::PathBuf; use cxx_build::CFG; +use pkg_config; const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"]; const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST"); +const CUDA_REQUIRED_VERSION: &str = "12.4"; +const MPI_REQUIRED_VERSION: &str = "4.1"; +const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX"); const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR"); +macro_rules! probe { + ($name: literal, $version: expr) => { + if let Err(_) = pkg_config::probe_library($name) { + pkg_config::probe_library(&format!("cuda-{}", $version)) + .expect(&format!("Failed to locate {}", $name)); + } + }; +} + fn main() { // Misc variables let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); @@ -24,11 +37,13 @@ fn main() { true => "Debug", false => "Release", }) + .env("OPT_LEVEL", "3") + .out_dir(INSTALL_PREFIX.unwrap_or("/usr/local/tgi")) + .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc") .define( "TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", CUDA_ARCH_LIST.unwrap_or("90-real"), // Hopper by default ) - .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc") .define( "TGI_TRTLLM_BACKEND_TRT_ROOT", TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt"), @@ -37,7 +52,6 @@ fn main() { // Additional transitive CMake dependencies let deps_folder = out_dir.join("build").join("_deps"); - for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES { let dep_name = match build_profile.as_ref() { "debug" => format!("{}d", dependency), @@ -68,55 +82,30 @@ fn main() { println!("cargo:rerun-if-changed=include/ffi.h"); println!("cargo:rerun-if-changed=src/ffi.cpp"); - // Emit linkage information - // - tgi_trtllm_backend (i.e. FFI layer - src/ffi.cpp) - let trtllm_lib_path = deps_folder - .join("trtllm-src") - .join("cpp") - .join("tensorrt_llm"); + // Emit linkage search path + probe!("ompi", MPI_REQUIRED_VERSION); - let trtllm_executor_linker_search_path = - trtllm_lib_path.join("executor").join("x86_64-linux-gnu"); + // Probe CUDA & co. with pkg-config + probe!("cuda", CUDA_REQUIRED_VERSION); + probe!("cudart", CUDA_REQUIRED_VERSION); + probe!("cublas", CUDA_REQUIRED_VERSION); + probe!("nvidia-ml", CUDA_REQUIRED_VERSION); - // TRTLLM libtensorrt_llm_nvrtc_wrapper.so - let trtllm_nvrtc_linker_search_path = trtllm_lib_path - .join("kernels") - .join("decoderMaskedMultiheadAttention") - .join("decoderXQAImplJIT") - .join("nvrtcWrapper") - .join("x86_64-linux-gnu"); - - println!(r"cargo:rustc-link-search=native=/usr/local/cuda/lib64"); - println!(r"cargo:rustc-link-search=native=/usr/local/cuda/lib64/stubs"); + // TensorRT println!(r"cargo:rustc-link-search=native=/usr/local/tensorrt/lib"); println!(r"cargo:rustc-link-search=native={}", backend_path.display()); - // println!( - // r"cargo:rustc-link-search=native={}/build", - // backend_path.display() - // ); + + // TensorRT-LLM println!( r"cargo:rustc-link-search=native={}", backend_path.join("lib").display() ); - println!( - r"cargo:rustc-link-search=native={}", - trtllm_executor_linker_search_path.display() - ); - println!( - r"cargo:rustc-link-search=native={}", - trtllm_nvrtc_linker_search_path.display() - ); - println!("cargo:rustc-link-lib=dylib=cuda"); - println!("cargo:rustc-link-lib=dylib=cudart"); - println!("cargo:rustc-link-lib=dylib=cublas"); - println!("cargo:rustc-link-lib=dylib=cublasLt"); - println!("cargo:rustc-link-lib=dylib=mpi"); - println!("cargo:rustc-link-lib=dylib=nvidia-ml"); - println!("cargo:rustc-link-lib=dylib=nvinfer"); + println!("cargo:rustc-link-lib=dylib=tensorrt_llm"); + println!("cargo:rustc-link-lib=static=tensorrt_llm_executor_static"); println!("cargo:rustc-link-lib=dylib=nvinfer_plugin_tensorrt_llm"); println!("cargo:rustc-link-lib=dylib=tensorrt_llm_nvrtc_wrapper"); - println!("cargo:rustc-link-lib=static=tensorrt_llm_executor_static"); - println!("cargo:rustc-link-lib=dylib=tensorrt_llm"); + + // Backend println!("cargo:rustc-link-lib=static=tgi_trtllm_backend_impl"); println!("cargo:rustc-link-lib=static=tgi_trtllm_backend"); }