text-generation-inference/backends/llamacpp/build.rs
2025-02-04 13:32:58 +00:00

43 lines
1.5 KiB
Rust

use std::collections::HashMap;
use std::env;
use std::path::PathBuf;
fn inject_transient_dependencies(lib_search_path: Option<&str>, lib_target_hardware: &str) {
let hardware_targets = HashMap::from([("cpu", None), ("cuda", Some(vec!["cuda"]))]);
if let Some(lib_search_path) = lib_search_path {
lib_search_path.split(":").for_each(|path| {
println!("cargo:rustc-link-search=dependency={path}");
});
}
if let Some(hardware_transient_deps) = hardware_targets.get(lib_target_hardware) {
if let Some(additional_transient_deps) = hardware_transient_deps {
additional_transient_deps.iter().for_each(|dep| {
println!("cargo:rustc-link-lib={dep}");
});
}
}
}
fn main() {
let lib_search_path = option_env!("TGI_LLAMA_LD_LIBRARY_PATH");
let lib_target_hardware = option_env!("TGI_LLAMA_HARDWARE_TARGET").unwrap_or("cpu");
let bindings = bindgen::Builder::default()
.header("src/wrapper.h")
.prepend_enum_name(false)
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.generate()
.expect("Unable to generate bindings");
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
pkg_config::Config::new().probe("llama").unwrap();
inject_transient_dependencies(lib_search_path, lib_target_hardware);
}