mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Handle custom llama.cpp dir
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
b6cfa0fbc0
commit
6bdb644f2c
@ -1,26 +1,7 @@
|
|||||||
use bindgen::callbacks::{ItemInfo, ParseCallbacks};
|
use bindgen::callbacks::{ItemInfo, ParseCallbacks};
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
fn inject_transient_dependencies(lib_search_path: Option<&str>, lib_target_hardware: &str) {
|
|
||||||
let hardware_targets = HashMap::from([("cpu", None), ("cuda", Some(vec!["cuda"]))]);
|
|
||||||
|
|
||||||
if let Some(lib_search_path) = lib_search_path {
|
|
||||||
lib_search_path.split(":").for_each(|path| {
|
|
||||||
println!("cargo:rustc-link-search=dependency={path}");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(hardware_transient_deps) = hardware_targets.get(lib_target_hardware) {
|
|
||||||
if let Some(additional_transient_deps) = hardware_transient_deps {
|
|
||||||
additional_transient_deps.iter().for_each(|dep| {
|
|
||||||
println!("cargo:rustc-link-lib={dep}");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct PrefixStripper;
|
struct PrefixStripper;
|
||||||
|
|
||||||
@ -31,9 +12,6 @@ impl ParseCallbacks for PrefixStripper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let lib_search_path = option_env!("TGI_LLAMA_LD_LIBRARY_PATH");
|
|
||||||
let lib_target_hardware = option_env!("TGI_LLAMA_HARDWARE_TARGET").unwrap_or("cpu");
|
|
||||||
|
|
||||||
if let Some(cuda_version) = option_env!("CUDA_VERSION") {
|
if let Some(cuda_version) = option_env!("CUDA_VERSION") {
|
||||||
let mut version: Vec<&str> = cuda_version.split('.').collect();
|
let mut version: Vec<&str> = cuda_version.split('.').collect();
|
||||||
if version.len() > 2 {
|
if version.len() > 2 {
|
||||||
@ -42,10 +20,21 @@ fn main() {
|
|||||||
let cuda_version = format!("cuda-{}", version.join("."));
|
let cuda_version = format!("cuda-{}", version.join("."));
|
||||||
pkg_config::Config::new().probe(&cuda_version).unwrap();
|
pkg_config::Config::new().probe(&cuda_version).unwrap();
|
||||||
}
|
}
|
||||||
pkg_config::Config::new().probe("llama").unwrap();
|
let llama = pkg_config::Config::new().probe("llama").unwrap();
|
||||||
|
|
||||||
|
for path in &llama.link_paths {
|
||||||
|
println!("cargo:rustc-link-arg=-Wl,-rpath,{}", path.display());
|
||||||
|
}
|
||||||
|
println!("cargo:rustc-link-arg=-Wl,--disable-new-dtags");
|
||||||
|
|
||||||
let bindings = bindgen::Builder::default()
|
let bindings = bindgen::Builder::default()
|
||||||
.header("src/wrapper.h")
|
.clang_args(
|
||||||
|
llama
|
||||||
|
.include_paths
|
||||||
|
.iter()
|
||||||
|
.map(|p| format!("-I{}", p.display())),
|
||||||
|
)
|
||||||
|
.header_contents("llama_bindings.h", "#include <llama.h>")
|
||||||
.prepend_enum_name(false)
|
.prepend_enum_name(false)
|
||||||
.parse_callbacks(Box::new(PrefixStripper))
|
.parse_callbacks(Box::new(PrefixStripper))
|
||||||
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
|
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
|
||||||
@ -56,6 +45,4 @@ fn main() {
|
|||||||
bindings
|
bindings
|
||||||
.write_to_file(out_path.join("llamacpp.rs"))
|
.write_to_file(out_path.join("llamacpp.rs"))
|
||||||
.expect("Couldn't write bindings!");
|
.expect("Couldn't write bindings!");
|
||||||
|
|
||||||
inject_transient_dependencies(lib_search_path, lib_target_hardware);
|
|
||||||
}
|
}
|
||||||
|
@ -1 +0,0 @@
|
|||||||
#include <llama.h>
|
|
Loading…
Reference in New Issue
Block a user