Rename bindings

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Author: Adrien Gallouët
Date: 2025-02-05 11:13:17 +00:00
parent c52f08351f
commit 051ff2d5ce
2 changed files with 112 additions and 100 deletions


@@ -1,3 +1,5 @@
use bindgen::callbacks::{ParseCallbacks, ItemInfo};
use std::collections::HashMap;
use std::env;
use std::path::PathBuf;
@@ -20,6 +22,15 @@ fn inject_transient_dependencies(lib_search_path: Option<&str>, lib_target_hardw
}
}
#[derive(Debug)]
struct PrefixStripper;
impl ParseCallbacks for PrefixStripper {
fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option<String> {
item_info.name.strip_prefix("llama_").map(str::to_string)
}
}
fn main() {
let pkg_cuda = option_env!("TGI_LLAMA_PKG_CUDA");
let lib_search_path = option_env!("TGI_LLAMA_LD_LIBRARY_PATH");
@@ -28,13 +39,14 @@ fn main() {
let bindings = bindgen::Builder::default()
.header("src/wrapper.h")
.prepend_enum_name(false)
.parse_callbacks(Box::new(PrefixStripper))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.generate()
.expect("Unable to generate bindings");
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("bindings.rs"))
.write_to_file(out_path.join("llamacpp.rs"))
.expect("Couldn't write bindings!");
if let Some(pkg_cuda) = pkg_cuda {

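A note on the rename mechanism in the build script above: the new PrefixStripper type hooks into bindgen's ParseCallbacks::generated_name_override, where returning Some(name) overrides the name of a generated item and None keeps bindgen's default. That is why llama_model_load_from_file shows up below as llamacpp::model_load_from_file, while names without the lowercase llama_ prefix (ggml_* functions, GGML_* constants) are unchanged; type names such as llama_token also keep their prefix in the diff, which is consistent with the hook applying to functions rather than types, though that is an inference from the diff, not something the commit states. The following is a minimal standalone sketch of the stripping rule only; the override_name helper and the main harness are illustrative and not part of the commit.

// Sketch (illustrative only): the renaming rule used by PrefixStripper.
fn override_name(name: &str) -> Option<String> {
    // Some(new_name) renames the generated item; None keeps bindgen's default.
    name.strip_prefix("llama_").map(str::to_string)
}

fn main() {
    // Names with the lowercase prefix are shortened...
    assert_eq!(
        override_name("llama_model_load_from_file").as_deref(),
        Some("model_load_from_file")
    );
    // ...while everything else is left to bindgen's default naming.
    assert_eq!(override_name("ggml_init"), None);
    assert_eq!(override_name("GGML_TYPE_F32"), None);
}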

@@ -1,9 +1,9 @@
mod bindings {
mod llamacpp {
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
#![allow(dead_code)]
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
include!(concat!(env!("OUT_DIR"), "/llamacpp.rs"));
}
use async_trait::async_trait;
use std::ffi::CString;
@@ -91,39 +91,39 @@ pub enum LlamacppGGMLType {
// TODO: macro
impl LlamacppGGMLType {
fn to_ggml_type(&self) -> bindings::ggml_type {
fn to_ggml_type(&self) -> llamacpp::ggml_type {
match self {
LlamacppGGMLType::F32 => bindings::GGML_TYPE_F32,
LlamacppGGMLType::F16 => bindings::GGML_TYPE_F16,
LlamacppGGMLType::Q4_0 => bindings::GGML_TYPE_Q4_0,
LlamacppGGMLType::Q4_1 => bindings::GGML_TYPE_Q4_1,
LlamacppGGMLType::Q5_0 => bindings::GGML_TYPE_Q5_0,
LlamacppGGMLType::Q5_1 => bindings::GGML_TYPE_Q5_1,
LlamacppGGMLType::Q8_0 => bindings::GGML_TYPE_Q8_0,
LlamacppGGMLType::Q8_1 => bindings::GGML_TYPE_Q8_1,
LlamacppGGMLType::Q2_K => bindings::GGML_TYPE_Q2_K,
LlamacppGGMLType::Q3_K => bindings::GGML_TYPE_Q3_K,
LlamacppGGMLType::Q4_K => bindings::GGML_TYPE_Q4_K,
LlamacppGGMLType::Q5_K => bindings::GGML_TYPE_Q5_K,
LlamacppGGMLType::Q6_K => bindings::GGML_TYPE_Q6_K,
LlamacppGGMLType::Q8_K => bindings::GGML_TYPE_Q8_K,
LlamacppGGMLType::IQ2_XXS => bindings::GGML_TYPE_IQ2_XXS,
LlamacppGGMLType::IQ2_XS => bindings::GGML_TYPE_IQ2_XS,
LlamacppGGMLType::IQ3_XXS => bindings::GGML_TYPE_IQ3_XXS,
LlamacppGGMLType::IQ1_S => bindings::GGML_TYPE_IQ1_S,
LlamacppGGMLType::IQ4_NL => bindings::GGML_TYPE_IQ4_NL,
LlamacppGGMLType::IQ3_S => bindings::GGML_TYPE_IQ3_S,
LlamacppGGMLType::IQ2_S => bindings::GGML_TYPE_IQ2_S,
LlamacppGGMLType::IQ4_XS => bindings::GGML_TYPE_IQ4_XS,
LlamacppGGMLType::I8 => bindings::GGML_TYPE_I8,
LlamacppGGMLType::I16 => bindings::GGML_TYPE_I16,
LlamacppGGMLType::I32 => bindings::GGML_TYPE_I32,
LlamacppGGMLType::I64 => bindings::GGML_TYPE_I64,
LlamacppGGMLType::F64 => bindings::GGML_TYPE_F64,
LlamacppGGMLType::IQ1_M => bindings::GGML_TYPE_IQ1_M,
LlamacppGGMLType::BF16 => bindings::GGML_TYPE_BF16,
LlamacppGGMLType::TQ1_0 => bindings::GGML_TYPE_TQ1_0,
LlamacppGGMLType::TQ2_0 => bindings::GGML_TYPE_TQ2_0,
LlamacppGGMLType::F32 => llamacpp::GGML_TYPE_F32,
LlamacppGGMLType::F16 => llamacpp::GGML_TYPE_F16,
LlamacppGGMLType::Q4_0 => llamacpp::GGML_TYPE_Q4_0,
LlamacppGGMLType::Q4_1 => llamacpp::GGML_TYPE_Q4_1,
LlamacppGGMLType::Q5_0 => llamacpp::GGML_TYPE_Q5_0,
LlamacppGGMLType::Q5_1 => llamacpp::GGML_TYPE_Q5_1,
LlamacppGGMLType::Q8_0 => llamacpp::GGML_TYPE_Q8_0,
LlamacppGGMLType::Q8_1 => llamacpp::GGML_TYPE_Q8_1,
LlamacppGGMLType::Q2_K => llamacpp::GGML_TYPE_Q2_K,
LlamacppGGMLType::Q3_K => llamacpp::GGML_TYPE_Q3_K,
LlamacppGGMLType::Q4_K => llamacpp::GGML_TYPE_Q4_K,
LlamacppGGMLType::Q5_K => llamacpp::GGML_TYPE_Q5_K,
LlamacppGGMLType::Q6_K => llamacpp::GGML_TYPE_Q6_K,
LlamacppGGMLType::Q8_K => llamacpp::GGML_TYPE_Q8_K,
LlamacppGGMLType::IQ2_XXS => llamacpp::GGML_TYPE_IQ2_XXS,
LlamacppGGMLType::IQ2_XS => llamacpp::GGML_TYPE_IQ2_XS,
LlamacppGGMLType::IQ3_XXS => llamacpp::GGML_TYPE_IQ3_XXS,
LlamacppGGMLType::IQ1_S => llamacpp::GGML_TYPE_IQ1_S,
LlamacppGGMLType::IQ4_NL => llamacpp::GGML_TYPE_IQ4_NL,
LlamacppGGMLType::IQ3_S => llamacpp::GGML_TYPE_IQ3_S,
LlamacppGGMLType::IQ2_S => llamacpp::GGML_TYPE_IQ2_S,
LlamacppGGMLType::IQ4_XS => llamacpp::GGML_TYPE_IQ4_XS,
LlamacppGGMLType::I8 => llamacpp::GGML_TYPE_I8,
LlamacppGGMLType::I16 => llamacpp::GGML_TYPE_I16,
LlamacppGGMLType::I32 => llamacpp::GGML_TYPE_I32,
LlamacppGGMLType::I64 => llamacpp::GGML_TYPE_I64,
LlamacppGGMLType::F64 => llamacpp::GGML_TYPE_F64,
LlamacppGGMLType::IQ1_M => llamacpp::GGML_TYPE_IQ1_M,
LlamacppGGMLType::BF16 => llamacpp::GGML_TYPE_BF16,
LlamacppGGMLType::TQ1_0 => llamacpp::GGML_TYPE_TQ1_0,
LlamacppGGMLType::TQ2_0 => llamacpp::GGML_TYPE_TQ2_0,
}
}
}
@@ -201,16 +201,16 @@ impl LlamacppRequest {
}
struct Llamacpp {
model: *mut bindings::llama_model,
ctx: *mut bindings::llama_context,
vocab: *const bindings::llama_vocab,
logprobs: Vec<bindings::llama_token_data>,
batch: bindings::llama_batch,
model: *mut llamacpp::llama_model,
ctx: *mut llamacpp::llama_context,
vocab: *const llamacpp::llama_vocab,
logprobs: Vec<llamacpp::llama_token_data>,
batch: llamacpp::llama_batch,
n_ctx: u32,
}
extern "C" fn llamacpp_log_callback(
level: bindings::ggml_log_level,
level: llamacpp::ggml_log_level,
msg: *const std::os::raw::c_char,
_user_data: *mut std::os::raw::c_void,
) {
@@ -218,10 +218,10 @@ extern "C" fn llamacpp_log_callback(
let rmsg = cmsg.to_string_lossy().trim_end_matches('\n').to_string();
match level {
bindings::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg),
bindings::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg),
bindings::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg),
bindings::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg),
llamacpp::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg),
llamacpp::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg),
llamacpp::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg),
llamacpp::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg),
_ => trace!(target: "llamacpp", "{}", rmsg),
}
}
@@ -231,12 +231,12 @@ impl Llamacpp {
let gguf = CString::new(conf.model_gguf)?;
let model = unsafe {
let mut params = bindings::llama_model_default_params();
let mut params = llamacpp::model_default_params();
params.n_gpu_layers = conf.n_gpu_layers as _;
params.split_mode = match conf.split_mode {
LlamacppSplitMode::GPU(_) => bindings::LLAMA_SPLIT_MODE_NONE,
LlamacppSplitMode::Layer => bindings::LLAMA_SPLIT_MODE_LAYER,
LlamacppSplitMode::Row => bindings::LLAMA_SPLIT_MODE_ROW,
LlamacppSplitMode::GPU(_) => llamacpp::LLAMA_SPLIT_MODE_NONE,
LlamacppSplitMode::Layer => llamacpp::LLAMA_SPLIT_MODE_LAYER,
LlamacppSplitMode::Row => llamacpp::LLAMA_SPLIT_MODE_ROW,
};
params.main_gpu = match conf.split_mode {
LlamacppSplitMode::GPU(n) => n as _,
@@ -244,13 +244,13 @@ impl Llamacpp {
};
params.use_mmap = conf.use_mmap;
params.use_mlock = conf.use_mlock;
bindings::llama_model_load_from_file(gguf.as_ptr(), params)
llamacpp::model_load_from_file(gguf.as_ptr(), params)
};
if model.is_null() {
return Err(BackendError::Llamacpp("Failed to load model".to_string()))
}
let ctx = unsafe {
let mut params = bindings::llama_context_default_params();
let mut params = llamacpp::context_default_params();
params.n_ctx = conf.n_ctx as _;
params.n_batch = conf.max_batch_total_tokens as _;
params.n_ubatch = conf.max_physical_batch_total_tokens as _;
@@ -263,48 +263,48 @@ impl Llamacpp {
params.type_k = conf.type_k.to_ggml_type();
params.type_v = conf.type_v.to_ggml_type();
params.no_perf = true;
bindings::llama_init_from_model(model, params)
llamacpp::init_from_model(model, params)
};
if ctx.is_null() {
return Err(BackendError::Llamacpp("Failed to init context".to_string()))
}
let n_ctx = unsafe { bindings::llama_n_ctx(ctx) };
let n_ctx = unsafe { llamacpp::n_ctx(ctx) };
let vocab = unsafe {
bindings::llama_model_get_vocab(model)
llamacpp::model_get_vocab(model)
};
if vocab.is_null() {
return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
}
let n_tokens = unsafe {
bindings::llama_vocab_n_tokens(vocab)
llamacpp::vocab_n_tokens(vocab)
};
let mut logprobs = Vec::with_capacity(n_tokens as usize);
for token in 0..n_tokens {
logprobs.push(bindings::llama_token_data {
logprobs.push(llamacpp::llama_token_data {
id: token,
logit: 0.0,
p: 0.0,
});
}
let batch = unsafe {
bindings::llama_batch_init(conf.max_batch_total_tokens as _, 0, 1)
llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1)
};
Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
}
fn clear_kv_cache(&mut self, seq_id: bindings::llama_seq_id) {
fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {
unsafe {
bindings::llama_kv_cache_seq_rm(self.ctx, seq_id, -1, -1);
llamacpp::kv_cache_seq_rm(self.ctx, seq_id, -1, -1);
}
}
fn batch_push(
&mut self,
token: bindings::llama_token,
pos: bindings::llama_pos,
seq_id: bindings::llama_seq_id,
token: llamacpp::llama_token,
pos: llamacpp::llama_pos,
seq_id: llamacpp::llama_seq_id,
logits: bool,
) -> usize {
let n = self.batch.n_tokens as usize;
@@ -323,43 +323,43 @@ impl Llamacpp {
impl Drop for Llamacpp {
fn drop(&mut self) {
if !self.ctx.is_null() {
unsafe { bindings::llama_free(self.ctx) };
unsafe { llamacpp::free(self.ctx) };
}
if !self.model.is_null() {
unsafe { bindings::llama_model_free(self.model) };
unsafe { llamacpp::model_free(self.model) };
}
unsafe { bindings::llama_batch_free(self.batch) };
unsafe { llamacpp::batch_free(self.batch) };
}
}
struct LlamacppSampler {
chain: *mut bindings::llama_sampler,
chain: *mut llamacpp::llama_sampler,
}
impl LlamacppSampler {
fn new(req: &LlamacppRequest) -> Option<Self> {
let chain = unsafe {
let params = bindings::llama_sampler_chain_default_params();
bindings::llama_sampler_chain_init(params)
let params = llamacpp::sampler_chain_default_params();
llamacpp::sampler_chain_init(params)
};
if chain.is_null() {
error!("Failed to init sampler");
return None;
}
let top_k = unsafe {
bindings::llama_sampler_init_top_k(req.top_k)
llamacpp::sampler_init_top_k(req.top_k)
};
let top_p = unsafe {
bindings::llama_sampler_init_top_p(req.top_p, req.min_keep)
llamacpp::sampler_init_top_p(req.top_p, req.min_keep)
};
let typical_p = unsafe {
bindings::llama_sampler_init_typical(req.typical_p, req.min_keep)
llamacpp::sampler_init_typical(req.typical_p, req.min_keep)
};
let temp = unsafe {
bindings::llama_sampler_init_temp(req.temp)
llamacpp::sampler_init_temp(req.temp)
};
let penalties = unsafe {
bindings::llama_sampler_init_penalties(
llamacpp::sampler_init_penalties(
req.penalty_last_n,
req.penalty_repeat,
req.penalty_freq,
@@ -367,7 +367,7 @@ impl LlamacppSampler {
)
};
let dist = unsafe {
bindings::llama_sampler_init_dist(req.seed)
llamacpp::sampler_init_dist(req.seed)
};
let mut failed = false;
@@ -381,7 +381,7 @@ impl LlamacppSampler {
error!("Failed to init {k} sampler");
failed = true;
} else {
unsafe { bindings::llama_sampler_chain_add(chain, *v) };
unsafe { llamacpp::sampler_chain_add(chain, *v) };
}
}
if failed {
@@ -391,27 +391,27 @@ impl LlamacppSampler {
}
}
fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (bindings::llama_token, f32) {
fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (llamacpp::llama_token, f32) {
let logits = unsafe {
bindings::llama_get_logits_ith(llamacpp.ctx, idx as _)
llamacpp::get_logits_ith(llamacpp.ctx, idx as _)
};
for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
*logprob = bindings::llama_token_data {
*logprob = llamacpp::llama_token_data {
id: token as _,
logit: unsafe { *logits.offset(token as _) },
p: 0.0,
};
}
let mut view = bindings::llama_token_data_array {
let mut view = llamacpp::llama_token_data_array {
data: llamacpp.logprobs.as_mut_ptr(),
size: llamacpp.logprobs.len(),
selected: -1,
sorted: false,
};
unsafe {
bindings::llama_sampler_apply(self.chain, &mut view);
llamacpp::sampler_apply(self.chain, &mut view);
let logprob = *view.data.offset(view.selected as _);
bindings::llama_sampler_accept(self.chain, logprob.id);
llamacpp::sampler_accept(self.chain, logprob.id);
(logprob.id, logprob.p.ln())
}
}
@@ -420,7 +420,7 @@ impl LlamacppSampler {
impl Drop for LlamacppSampler {
fn drop(&mut self) {
if !self.chain.is_null() {
unsafe { bindings::llama_sampler_free(self.chain) };
unsafe { llamacpp::sampler_free(self.chain) };
}
}
}
@@ -428,8 +428,8 @@ impl Drop for LlamacppSampler {
struct LlamacppSeq {
id: usize,
batch_pos: usize,
token: bindings::llama_token,
pos: bindings::llama_pos,
token: llamacpp::llama_token,
pos: llamacpp::llama_pos,
sampler: LlamacppSampler,
text: String,
n_new_tokens: usize,
@@ -446,14 +446,14 @@ impl LlamacppBackend {
// Setup llama & export logs, once and for all
INIT.call_once(|| unsafe {
bindings::llama_log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
bindings::llama_backend_init();
bindings::llama_numa_init(match conf.numa {
LlamacppNuma::Disabled => bindings::GGML_NUMA_STRATEGY_DISABLED,
LlamacppNuma::Distribute => bindings::GGML_NUMA_STRATEGY_DISTRIBUTE,
LlamacppNuma::Isolate => bindings::GGML_NUMA_STRATEGY_ISOLATE,
LlamacppNuma::Numactl => bindings::GGML_NUMA_STRATEGY_NUMACTL,
LlamacppNuma::Mirror => bindings::GGML_NUMA_STRATEGY_MIRROR,
llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
llamacpp::backend_init();
llamacpp::numa_init(match conf.numa {
LlamacppNuma::Disabled => llamacpp::GGML_NUMA_STRATEGY_DISABLED,
LlamacppNuma::Distribute => llamacpp::GGML_NUMA_STRATEGY_DISTRIBUTE,
LlamacppNuma::Isolate => llamacpp::GGML_NUMA_STRATEGY_ISOLATE,
LlamacppNuma::Numactl => llamacpp::GGML_NUMA_STRATEGY_NUMACTL,
LlamacppNuma::Mirror => llamacpp::GGML_NUMA_STRATEGY_MIRROR,
});
});
@@ -526,17 +526,17 @@ impl LlamacppBackend {
for (pos, &token_id) in request.input_ids.iter().enumerate() {
llamacpp.batch_push(
token_id as bindings::llama_token,
pos as bindings::llama_pos,
seq_id as bindings::llama_seq_id,
token_id as llamacpp::llama_token,
pos as llamacpp::llama_pos,
seq_id as llamacpp::llama_seq_id,
pos == last_pos, // check samplers
);
}
seqs.push(LlamacppSeq {
id: seq_id,
batch_pos: llamacpp.batch.n_tokens as usize - 1,
token: bindings::LLAMA_TOKEN_NULL,
pos: last_pos as bindings::llama_pos + 1,
token: llamacpp::LLAMA_TOKEN_NULL,
pos: last_pos as llamacpp::llama_pos + 1,
sampler: sampler,
text: String::with_capacity(1024),
n_new_tokens: 0,
@@ -548,7 +548,7 @@ impl LlamacppBackend {
break;
}
let decode = unsafe {
bindings::llama_decode(llamacpp.ctx, llamacpp.batch)
llamacpp::decode(llamacpp.ctx, llamacpp.batch)
};
if decode != 0 {
warn!("llama_decode failed, clearing kv cache");
@@ -560,7 +560,7 @@ impl LlamacppBackend {
break;
}
let kv_cache_used_cells = unsafe {
bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
llamacpp::get_kv_cache_used_cells(llamacpp.ctx)
};
for seq in seqs.iter_mut() {
if !seq.running {
@@ -591,7 +591,7 @@ impl LlamacppBackend {
special: special,
};
let finish: Option<FinishReason> = {
if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
if unsafe { llamacpp::vocab_is_eog(llamacpp.vocab, next) } {
Some(FinishReason::EndOfSequenceToken)
} else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
Some(FinishReason::Length)