From 051ff2d5ce442ebb14f1abc796438a3087949341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Wed, 5 Feb 2025 11:13:17 +0000 Subject: [PATCH] Rename bindings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- backends/llamacpp/build.rs | 14 ++- backends/llamacpp/src/backend.rs | 198 +++++++++++++++---------------- 2 files changed, 112 insertions(+), 100 deletions(-) diff --git a/backends/llamacpp/build.rs b/backends/llamacpp/build.rs index e56272ee..1b1c3718 100644 --- a/backends/llamacpp/build.rs +++ b/backends/llamacpp/build.rs @@ -1,3 +1,5 @@ + +use bindgen::callbacks::{ParseCallbacks, ItemInfo}; use std::collections::HashMap; use std::env; use std::path::PathBuf; @@ -20,6 +22,15 @@ fn inject_transient_dependencies(lib_search_path: Option<&str>, lib_target_hardw } } +#[derive(Debug)] +struct PrefixStripper; + +impl ParseCallbacks for PrefixStripper { + fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option { + item_info.name.strip_prefix("llama_").map(str::to_string) + } +} + fn main() { let pkg_cuda = option_env!("TGI_LLAMA_PKG_CUDA"); let lib_search_path = option_env!("TGI_LLAMA_LD_LIBRARY_PATH"); @@ -28,13 +39,14 @@ fn main() { let bindings = bindgen::Builder::default() .header("src/wrapper.h") .prepend_enum_name(false) + .parse_callbacks(Box::new(PrefixStripper)) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .generate() .expect("Unable to generate bindings"); let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); bindings - .write_to_file(out_path.join("bindings.rs")) + .write_to_file(out_path.join("llamacpp.rs")) .expect("Couldn't write bindings!"); if let Some(pkg_cuda) = pkg_cuda { diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs index c6f4e925..aa44df31 100644 --- a/backends/llamacpp/src/backend.rs +++ b/backends/llamacpp/src/backend.rs @@ -1,9 +1,9 @@ -mod bindings { +mod llamacpp { #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] #![allow(dead_code)] - include!(concat!(env!("OUT_DIR"), "/bindings.rs")); + include!(concat!(env!("OUT_DIR"), "/llamacpp.rs")); } use async_trait::async_trait; use std::ffi::CString; @@ -91,39 +91,39 @@ pub enum LlamacppGGMLType { // TODO: macro impl LlamacppGGMLType { - fn to_ggml_type(&self) -> bindings::ggml_type { + fn to_ggml_type(&self) -> llamacpp::ggml_type { match self { - LlamacppGGMLType::F32 => bindings::GGML_TYPE_F32, - LlamacppGGMLType::F16 => bindings::GGML_TYPE_F16, - LlamacppGGMLType::Q4_0 => bindings::GGML_TYPE_Q4_0, - LlamacppGGMLType::Q4_1 => bindings::GGML_TYPE_Q4_1, - LlamacppGGMLType::Q5_0 => bindings::GGML_TYPE_Q5_0, - LlamacppGGMLType::Q5_1 => bindings::GGML_TYPE_Q5_1, - LlamacppGGMLType::Q8_0 => bindings::GGML_TYPE_Q8_0, - LlamacppGGMLType::Q8_1 => bindings::GGML_TYPE_Q8_1, - LlamacppGGMLType::Q2_K => bindings::GGML_TYPE_Q2_K, - LlamacppGGMLType::Q3_K => bindings::GGML_TYPE_Q3_K, - LlamacppGGMLType::Q4_K => bindings::GGML_TYPE_Q4_K, - LlamacppGGMLType::Q5_K => bindings::GGML_TYPE_Q5_K, - LlamacppGGMLType::Q6_K => bindings::GGML_TYPE_Q6_K, - LlamacppGGMLType::Q8_K => bindings::GGML_TYPE_Q8_K, - LlamacppGGMLType::IQ2_XXS => bindings::GGML_TYPE_IQ2_XXS, - LlamacppGGMLType::IQ2_XS => bindings::GGML_TYPE_IQ2_XS, - LlamacppGGMLType::IQ3_XXS => bindings::GGML_TYPE_IQ3_XXS, - LlamacppGGMLType::IQ1_S => bindings::GGML_TYPE_IQ1_S, - LlamacppGGMLType::IQ4_NL => bindings::GGML_TYPE_IQ4_NL, - LlamacppGGMLType::IQ3_S => bindings::GGML_TYPE_IQ3_S, - LlamacppGGMLType::IQ2_S => bindings::GGML_TYPE_IQ2_S, - LlamacppGGMLType::IQ4_XS => bindings::GGML_TYPE_IQ4_XS, - LlamacppGGMLType::I8 => bindings::GGML_TYPE_I8, - LlamacppGGMLType::I16 => bindings::GGML_TYPE_I16, - LlamacppGGMLType::I32 => bindings::GGML_TYPE_I32, - LlamacppGGMLType::I64 => bindings::GGML_TYPE_I64, - LlamacppGGMLType::F64 => bindings::GGML_TYPE_F64, - LlamacppGGMLType::IQ1_M => bindings::GGML_TYPE_IQ1_M, - LlamacppGGMLType::BF16 => bindings::GGML_TYPE_BF16, - LlamacppGGMLType::TQ1_0 => bindings::GGML_TYPE_TQ1_0, - LlamacppGGMLType::TQ2_0 => bindings::GGML_TYPE_TQ2_0, + LlamacppGGMLType::F32 => llamacpp::GGML_TYPE_F32, + LlamacppGGMLType::F16 => llamacpp::GGML_TYPE_F16, + LlamacppGGMLType::Q4_0 => llamacpp::GGML_TYPE_Q4_0, + LlamacppGGMLType::Q4_1 => llamacpp::GGML_TYPE_Q4_1, + LlamacppGGMLType::Q5_0 => llamacpp::GGML_TYPE_Q5_0, + LlamacppGGMLType::Q5_1 => llamacpp::GGML_TYPE_Q5_1, + LlamacppGGMLType::Q8_0 => llamacpp::GGML_TYPE_Q8_0, + LlamacppGGMLType::Q8_1 => llamacpp::GGML_TYPE_Q8_1, + LlamacppGGMLType::Q2_K => llamacpp::GGML_TYPE_Q2_K, + LlamacppGGMLType::Q3_K => llamacpp::GGML_TYPE_Q3_K, + LlamacppGGMLType::Q4_K => llamacpp::GGML_TYPE_Q4_K, + LlamacppGGMLType::Q5_K => llamacpp::GGML_TYPE_Q5_K, + LlamacppGGMLType::Q6_K => llamacpp::GGML_TYPE_Q6_K, + LlamacppGGMLType::Q8_K => llamacpp::GGML_TYPE_Q8_K, + LlamacppGGMLType::IQ2_XXS => llamacpp::GGML_TYPE_IQ2_XXS, + LlamacppGGMLType::IQ2_XS => llamacpp::GGML_TYPE_IQ2_XS, + LlamacppGGMLType::IQ3_XXS => llamacpp::GGML_TYPE_IQ3_XXS, + LlamacppGGMLType::IQ1_S => llamacpp::GGML_TYPE_IQ1_S, + LlamacppGGMLType::IQ4_NL => llamacpp::GGML_TYPE_IQ4_NL, + LlamacppGGMLType::IQ3_S => llamacpp::GGML_TYPE_IQ3_S, + LlamacppGGMLType::IQ2_S => llamacpp::GGML_TYPE_IQ2_S, + LlamacppGGMLType::IQ4_XS => llamacpp::GGML_TYPE_IQ4_XS, + LlamacppGGMLType::I8 => llamacpp::GGML_TYPE_I8, + LlamacppGGMLType::I16 => llamacpp::GGML_TYPE_I16, + LlamacppGGMLType::I32 => llamacpp::GGML_TYPE_I32, + LlamacppGGMLType::I64 => llamacpp::GGML_TYPE_I64, + LlamacppGGMLType::F64 => llamacpp::GGML_TYPE_F64, + LlamacppGGMLType::IQ1_M => llamacpp::GGML_TYPE_IQ1_M, + LlamacppGGMLType::BF16 => llamacpp::GGML_TYPE_BF16, + LlamacppGGMLType::TQ1_0 => llamacpp::GGML_TYPE_TQ1_0, + LlamacppGGMLType::TQ2_0 => llamacpp::GGML_TYPE_TQ2_0, } } } @@ -201,16 +201,16 @@ impl LlamacppRequest { } struct Llamacpp { - model: *mut bindings::llama_model, - ctx: *mut bindings::llama_context, - vocab: *const bindings::llama_vocab, - logprobs: Vec, - batch: bindings::llama_batch, + model: *mut llamacpp::llama_model, + ctx: *mut llamacpp::llama_context, + vocab: *const llamacpp::llama_vocab, + logprobs: Vec, + batch: llamacpp::llama_batch, n_ctx: u32, } extern "C" fn llamacpp_log_callback( - level: bindings::ggml_log_level, + level: llamacpp::ggml_log_level, msg: *const std::os::raw::c_char, _user_data: *mut std::os::raw::c_void, ) { @@ -218,10 +218,10 @@ extern "C" fn llamacpp_log_callback( let rmsg = cmsg.to_string_lossy().trim_end_matches('\n').to_string(); match level { - bindings::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg), - bindings::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg), - bindings::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg), - bindings::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg), + llamacpp::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg), + llamacpp::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg), + llamacpp::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg), + llamacpp::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg), _ => trace!(target: "llamacpp", "{}", rmsg), } } @@ -231,12 +231,12 @@ impl Llamacpp { let gguf = CString::new(conf.model_gguf)?; let model = unsafe { - let mut params = bindings::llama_model_default_params(); + let mut params = llamacpp::model_default_params(); params.n_gpu_layers = conf.n_gpu_layers as _; params.split_mode = match conf.split_mode { - LlamacppSplitMode::GPU(_) => bindings::LLAMA_SPLIT_MODE_NONE, - LlamacppSplitMode::Layer => bindings::LLAMA_SPLIT_MODE_LAYER, - LlamacppSplitMode::Row => bindings::LLAMA_SPLIT_MODE_ROW, + LlamacppSplitMode::GPU(_) => llamacpp::LLAMA_SPLIT_MODE_NONE, + LlamacppSplitMode::Layer => llamacpp::LLAMA_SPLIT_MODE_LAYER, + LlamacppSplitMode::Row => llamacpp::LLAMA_SPLIT_MODE_ROW, }; params.main_gpu = match conf.split_mode { LlamacppSplitMode::GPU(n) => n as _, @@ -244,13 +244,13 @@ impl Llamacpp { }; params.use_mmap = conf.use_mmap; params.use_mlock = conf.use_mlock; - bindings::llama_model_load_from_file(gguf.as_ptr(), params) + llamacpp::model_load_from_file(gguf.as_ptr(), params) }; if model.is_null() { return Err(BackendError::Llamacpp("Failed to load model".to_string())) } let ctx = unsafe { - let mut params = bindings::llama_context_default_params(); + let mut params = llamacpp::context_default_params(); params.n_ctx = conf.n_ctx as _; params.n_batch = conf.max_batch_total_tokens as _; params.n_ubatch = conf.max_physical_batch_total_tokens as _; @@ -263,48 +263,48 @@ impl Llamacpp { params.type_k = conf.type_k.to_ggml_type(); params.type_v = conf.type_v.to_ggml_type(); params.no_perf = true; - bindings::llama_init_from_model(model, params) + llamacpp::init_from_model(model, params) }; if ctx.is_null() { return Err(BackendError::Llamacpp("Failed to init context".to_string())) } - let n_ctx = unsafe { bindings::llama_n_ctx(ctx) }; + let n_ctx = unsafe { llamacpp::n_ctx(ctx) }; let vocab = unsafe { - bindings::llama_model_get_vocab(model) + llamacpp::model_get_vocab(model) }; if vocab.is_null() { return Err(BackendError::Llamacpp("Failed to get vocab".to_string())); } let n_tokens = unsafe { - bindings::llama_vocab_n_tokens(vocab) + llamacpp::vocab_n_tokens(vocab) }; let mut logprobs = Vec::with_capacity(n_tokens as usize); for token in 0..n_tokens { - logprobs.push(bindings::llama_token_data { + logprobs.push(llamacpp::llama_token_data { id: token, logit: 0.0, p: 0.0, }); } let batch = unsafe { - bindings::llama_batch_init(conf.max_batch_total_tokens as _, 0, 1) + llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1) }; Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch}) } - fn clear_kv_cache(&mut self, seq_id: bindings::llama_seq_id) { + fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) { unsafe { - bindings::llama_kv_cache_seq_rm(self.ctx, seq_id, -1, -1); + llamacpp::kv_cache_seq_rm(self.ctx, seq_id, -1, -1); } } fn batch_push( &mut self, - token: bindings::llama_token, - pos: bindings::llama_pos, - seq_id: bindings::llama_seq_id, + token: llamacpp::llama_token, + pos: llamacpp::llama_pos, + seq_id: llamacpp::llama_seq_id, logits: bool, ) -> usize { let n = self.batch.n_tokens as usize; @@ -323,43 +323,43 @@ impl Llamacpp { impl Drop for Llamacpp { fn drop(&mut self) { if !self.ctx.is_null() { - unsafe { bindings::llama_free(self.ctx) }; + unsafe { llamacpp::free(self.ctx) }; } if !self.model.is_null() { - unsafe { bindings::llama_model_free(self.model) }; + unsafe { llamacpp::model_free(self.model) }; } - unsafe { bindings::llama_batch_free(self.batch) }; + unsafe { llamacpp::batch_free(self.batch) }; } } struct LlamacppSampler { - chain: *mut bindings::llama_sampler, + chain: *mut llamacpp::llama_sampler, } impl LlamacppSampler { fn new(req: &LlamacppRequest) -> Option { let chain = unsafe { - let params = bindings::llama_sampler_chain_default_params(); - bindings::llama_sampler_chain_init(params) + let params = llamacpp::sampler_chain_default_params(); + llamacpp::sampler_chain_init(params) }; if chain.is_null() { error!("Failed to init sampler"); return None; } let top_k = unsafe { - bindings::llama_sampler_init_top_k(req.top_k) + llamacpp::sampler_init_top_k(req.top_k) }; let top_p = unsafe { - bindings::llama_sampler_init_top_p(req.top_p, req.min_keep) + llamacpp::sampler_init_top_p(req.top_p, req.min_keep) }; let typical_p = unsafe { - bindings::llama_sampler_init_typical(req.typical_p, req.min_keep) + llamacpp::sampler_init_typical(req.typical_p, req.min_keep) }; let temp = unsafe { - bindings::llama_sampler_init_temp(req.temp) + llamacpp::sampler_init_temp(req.temp) }; let penalties = unsafe { - bindings::llama_sampler_init_penalties( + llamacpp::sampler_init_penalties( req.penalty_last_n, req.penalty_repeat, req.penalty_freq, @@ -367,7 +367,7 @@ impl LlamacppSampler { ) }; let dist = unsafe { - bindings::llama_sampler_init_dist(req.seed) + llamacpp::sampler_init_dist(req.seed) }; let mut failed = false; @@ -381,7 +381,7 @@ impl LlamacppSampler { error!("Failed to init {k} sampler"); failed = true; } else { - unsafe { bindings::llama_sampler_chain_add(chain, *v) }; + unsafe { llamacpp::sampler_chain_add(chain, *v) }; } } if failed { @@ -391,27 +391,27 @@ impl LlamacppSampler { } } - fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (bindings::llama_token, f32) { + fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (llamacpp::llama_token, f32) { let logits = unsafe { - bindings::llama_get_logits_ith(llamacpp.ctx, idx as _) + llamacpp::get_logits_ith(llamacpp.ctx, idx as _) }; for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() { - *logprob = bindings::llama_token_data { + *logprob = llamacpp::llama_token_data { id: token as _, logit: unsafe { *logits.offset(token as _) }, p: 0.0, }; } - let mut view = bindings::llama_token_data_array { + let mut view = llamacpp::llama_token_data_array { data: llamacpp.logprobs.as_mut_ptr(), size: llamacpp.logprobs.len(), selected: -1, sorted: false, }; unsafe { - bindings::llama_sampler_apply(self.chain, &mut view); + llamacpp::sampler_apply(self.chain, &mut view); let logprob = *view.data.offset(view.selected as _); - bindings::llama_sampler_accept(self.chain, logprob.id); + llamacpp::sampler_accept(self.chain, logprob.id); (logprob.id, logprob.p.ln()) } } @@ -420,7 +420,7 @@ impl LlamacppSampler { impl Drop for LlamacppSampler { fn drop(&mut self) { if !self.chain.is_null() { - unsafe { bindings::llama_sampler_free(self.chain) }; + unsafe { llamacpp::sampler_free(self.chain) }; } } } @@ -428,8 +428,8 @@ impl Drop for LlamacppSampler { struct LlamacppSeq { id: usize, batch_pos: usize, - token: bindings::llama_token, - pos: bindings::llama_pos, + token: llamacpp::llama_token, + pos: llamacpp::llama_pos, sampler: LlamacppSampler, text: String, n_new_tokens: usize, @@ -446,14 +446,14 @@ impl LlamacppBackend { // Setup llama & export logs, once and for all INIT.call_once(|| unsafe { - bindings::llama_log_set(Some(llamacpp_log_callback), std::ptr::null_mut()); - bindings::llama_backend_init(); - bindings::llama_numa_init(match conf.numa { - LlamacppNuma::Disabled => bindings::GGML_NUMA_STRATEGY_DISABLED, - LlamacppNuma::Distribute => bindings::GGML_NUMA_STRATEGY_DISTRIBUTE, - LlamacppNuma::Isolate => bindings::GGML_NUMA_STRATEGY_ISOLATE, - LlamacppNuma::Numactl => bindings::GGML_NUMA_STRATEGY_NUMACTL, - LlamacppNuma::Mirror => bindings::GGML_NUMA_STRATEGY_MIRROR, + llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut()); + llamacpp::backend_init(); + llamacpp::numa_init(match conf.numa { + LlamacppNuma::Disabled => llamacpp::GGML_NUMA_STRATEGY_DISABLED, + LlamacppNuma::Distribute => llamacpp::GGML_NUMA_STRATEGY_DISTRIBUTE, + LlamacppNuma::Isolate => llamacpp::GGML_NUMA_STRATEGY_ISOLATE, + LlamacppNuma::Numactl => llamacpp::GGML_NUMA_STRATEGY_NUMACTL, + LlamacppNuma::Mirror => llamacpp::GGML_NUMA_STRATEGY_MIRROR, }); }); @@ -526,17 +526,17 @@ impl LlamacppBackend { for (pos, &token_id) in request.input_ids.iter().enumerate() { llamacpp.batch_push( - token_id as bindings::llama_token, - pos as bindings::llama_pos, - seq_id as bindings::llama_seq_id, + token_id as llamacpp::llama_token, + pos as llamacpp::llama_pos, + seq_id as llamacpp::llama_seq_id, pos == last_pos, // check samplers ); } seqs.push(LlamacppSeq { id: seq_id, batch_pos: llamacpp.batch.n_tokens as usize - 1, - token: bindings::LLAMA_TOKEN_NULL, - pos: last_pos as bindings::llama_pos + 1, + token: llamacpp::LLAMA_TOKEN_NULL, + pos: last_pos as llamacpp::llama_pos + 1, sampler: sampler, text: String::with_capacity(1024), n_new_tokens: 0, @@ -548,7 +548,7 @@ impl LlamacppBackend { break; } let decode = unsafe { - bindings::llama_decode(llamacpp.ctx, llamacpp.batch) + llamacpp::decode(llamacpp.ctx, llamacpp.batch) }; if decode != 0 { warn!("llama_decode failed, clearing kv cache"); @@ -560,7 +560,7 @@ impl LlamacppBackend { break; } let kv_cache_used_cells = unsafe { - bindings::llama_get_kv_cache_used_cells(llamacpp.ctx) + llamacpp::get_kv_cache_used_cells(llamacpp.ctx) }; for seq in seqs.iter_mut() { if !seq.running { @@ -591,7 +591,7 @@ impl LlamacppBackend { special: special, }; let finish: Option = { - if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } { + if unsafe { llamacpp::vocab_is_eog(llamacpp.vocab, next) } { Some(FinishReason::EndOfSequenceToken) } else if seq.n_new_tokens == requests[seq.id].max_new_tokens { Some(FinishReason::Length)