diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs index def5acda..5665f161 100644 --- a/backends/trtllm/build.rs +++ b/backends/trtllm/build.rs @@ -12,7 +12,7 @@ const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX"); const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR"); macro_rules! probe { - ($name: literal, $version: expr) => { + ($name: expr, $version: expr) => { if let Err(_) = pkg_config::probe_library($name) { pkg_config::probe_library(&format!("{}-{}", $name, $version)) .expect(&format!("Failed to locate {}", $name)); @@ -109,25 +109,34 @@ fn main() { probe!("ompi", MPI_REQUIRED_VERSION); // Probe CUDA & co. with pkg-config - probe!("cuda", CUDA_REQUIRED_VERSION); - probe!("cudart", CUDA_REQUIRED_VERSION); - probe!("cublas", CUDA_REQUIRED_VERSION); - probe!("nvidia-ml", CUDA_REQUIRED_VERSION); + const CUDA_TRANSITIVE_DEPS: [&str; 5] = ["cuda", "cudart", "cublas", "nvidia-ml", "nccl"]; + CUDA_TRANSITIVE_DEPS.iter().for_each(|name| { + probe!(name, CUDA_REQUIRED_VERSION); + }); // TensorRT - println!( - r"cargo:rustc-link-search=native={}", - TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib") - ); + let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib"); + println!(r"cargo:rustc-link-search=native={}", tensort_library_path); println!("cargo:rustc-link-lib=dylib=nvinfer"); // TensorRT-LLM - println!("cargo:rustc-link-lib=dylib=tensorrt_llm"); - println!("cargo:rustc-link-lib=static=tensorrt_llm_executor_static"); - println!("cargo:rustc-link-lib=dylib=nvinfer_plugin_tensorrt_llm"); - println!("cargo:rustc-link-lib=dylib=tensorrt_llm_nvrtc_wrapper"); + const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [ + ("dylib", "tensorrt-llm"), + ("static", "tensorrt_llm_executor_static"), + ("dylib", "tensorrt_llm_nvrtc_wrapper"), + ("dylib", "nvinfer_plugin_tensorrt_llm"), + ("dylib", "decoder_attention"), + ]; + + TENSORRT_LLM_TRANSITIVE_DEPS + .iter() + .for_each(|(link_type, name)| { + println!("cargo:rustc-link-lib={}={}", link_type, name); + }); // Backend - println!("cargo:rustc-link-lib=static=tgi_trtllm_backend_impl"); - println!("cargo:rustc-link-lib=static=tgi_trtllm_backend"); + const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"]; + BACKEND_DEPS.iter().for_each(|name| { + println!("cargo:rustc-link-lib=static={}", name); + }); }