diff --git a/backends/llamacpp/build.rs b/backends/llamacpp/build.rs index 19b1987d..499583cd 100644 --- a/backends/llamacpp/build.rs +++ b/backends/llamacpp/build.rs @@ -1,26 +1,7 @@ use bindgen::callbacks::{ItemInfo, ParseCallbacks}; -use std::collections::HashMap; use std::env; use std::path::PathBuf; -fn inject_transient_dependencies(lib_search_path: Option<&str>, lib_target_hardware: &str) { - let hardware_targets = HashMap::from([("cpu", None), ("cuda", Some(vec!["cuda"]))]); - - if let Some(lib_search_path) = lib_search_path { - lib_search_path.split(":").for_each(|path| { - println!("cargo:rustc-link-search=dependency={path}"); - }); - } - - if let Some(hardware_transient_deps) = hardware_targets.get(lib_target_hardware) { - if let Some(additional_transient_deps) = hardware_transient_deps { - additional_transient_deps.iter().for_each(|dep| { - println!("cargo:rustc-link-lib={dep}"); - }); - } - } -} - #[derive(Debug)] struct PrefixStripper; @@ -31,9 +12,6 @@ impl ParseCallbacks for PrefixStripper { } fn main() { - let lib_search_path = option_env!("TGI_LLAMA_LD_LIBRARY_PATH"); - let lib_target_hardware = option_env!("TGI_LLAMA_HARDWARE_TARGET").unwrap_or("cpu"); - if let Some(cuda_version) = option_env!("CUDA_VERSION") { let mut version: Vec<&str> = cuda_version.split('.').collect(); if version.len() > 2 { @@ -42,10 +20,21 @@ fn main() { let cuda_version = format!("cuda-{}", version.join(".")); pkg_config::Config::new().probe(&cuda_version).unwrap(); } - pkg_config::Config::new().probe("llama").unwrap(); + let llama = pkg_config::Config::new().probe("llama").unwrap(); + + for path in &llama.link_paths { + println!("cargo:rustc-link-arg=-Wl,-rpath,{}", path.display()); + } + println!("cargo:rustc-link-arg=-Wl,--disable-new-dtags"); let bindings = bindgen::Builder::default() - .header("src/wrapper.h") + .clang_args( + llama + .include_paths + .iter() + .map(|p| format!("-I{}", p.display())), + ) + .header_contents("llama_bindings.h", "#include ") .prepend_enum_name(false) .parse_callbacks(Box::new(PrefixStripper)) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) @@ -56,6 +45,4 @@ fn main() { bindings .write_to_file(out_path.join("llamacpp.rs")) .expect("Couldn't write bindings!"); - - inject_transient_dependencies(lib_search_path, lib_target_hardware); } diff --git a/backends/llamacpp/src/wrapper.h b/backends/llamacpp/src/wrapper.h deleted file mode 100644 index 630ebeec..00000000 --- a/backends/llamacpp/src/wrapper.h +++ /dev/null @@ -1 +0,0 @@ -#include