Handle custom llama.cpp dir

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2025-02-07 12:08:02 +00:00
parent b6cfa0fbc0
commit 6bdb644f2c
No known key found for this signature in database
2 changed files with 13 additions and 27 deletions

View File

@ -1,26 +1,7 @@
use bindgen::callbacks::{ItemInfo, ParseCallbacks};
use std::collections::HashMap;
use std::env;
use std::path::PathBuf;
fn inject_transient_dependencies(lib_search_path: Option<&str>, lib_target_hardware: &str) {
let hardware_targets = HashMap::from([("cpu", None), ("cuda", Some(vec!["cuda"]))]);
if let Some(lib_search_path) = lib_search_path {
lib_search_path.split(":").for_each(|path| {
println!("cargo:rustc-link-search=dependency={path}");
});
}
if let Some(hardware_transient_deps) = hardware_targets.get(lib_target_hardware) {
if let Some(additional_transient_deps) = hardware_transient_deps {
additional_transient_deps.iter().for_each(|dep| {
println!("cargo:rustc-link-lib={dep}");
});
}
}
}
#[derive(Debug)]
struct PrefixStripper;
@ -31,9 +12,6 @@ impl ParseCallbacks for PrefixStripper {
}
fn main() {
let lib_search_path = option_env!("TGI_LLAMA_LD_LIBRARY_PATH");
let lib_target_hardware = option_env!("TGI_LLAMA_HARDWARE_TARGET").unwrap_or("cpu");
if let Some(cuda_version) = option_env!("CUDA_VERSION") {
let mut version: Vec<&str> = cuda_version.split('.').collect();
if version.len() > 2 {
@ -42,10 +20,21 @@ fn main() {
let cuda_version = format!("cuda-{}", version.join("."));
pkg_config::Config::new().probe(&cuda_version).unwrap();
}
pkg_config::Config::new().probe("llama").unwrap();
let llama = pkg_config::Config::new().probe("llama").unwrap();
for path in &llama.link_paths {
println!("cargo:rustc-link-arg=-Wl,-rpath,{}", path.display());
}
println!("cargo:rustc-link-arg=-Wl,--disable-new-dtags");
let bindings = bindgen::Builder::default()
.header("src/wrapper.h")
.clang_args(
llama
.include_paths
.iter()
.map(|p| format!("-I{}", p.display())),
)
.header_contents("llama_bindings.h", "#include <llama.h>")
.prepend_enum_name(false)
.parse_callbacks(Box::new(PrefixStripper))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
@ -56,6 +45,4 @@ fn main() {
bindings
.write_to_file(out_path.join("llamacpp.rs"))
.expect("Couldn't write bindings!");
inject_transient_dependencies(lib_search_path, lib_target_hardware);
}

View File

@ -1 +0,0 @@
#include <llama.h>