Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
Rename bindings
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
parent c52f08351f, commit 051ff2d5ce
@@ -1,3 +1,5 @@
+use bindgen::callbacks::{ParseCallbacks, ItemInfo};
+
 use std::collections::HashMap;
 use std::env;
 use std::path::PathBuf;
@@ -20,6 +22,15 @@ fn inject_transient_dependencies(lib_search_path: Option<&str>, lib_target_hardw
     }
 }

+#[derive(Debug)]
+struct PrefixStripper;
+
+impl ParseCallbacks for PrefixStripper {
+    fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option<String> {
+        item_info.name.strip_prefix("llama_").map(str::to_string)
+    }
+}
+
 fn main() {
     let pkg_cuda = option_env!("TGI_LLAMA_PKG_CUDA");
     let lib_search_path = option_env!("TGI_LLAMA_LD_LIBRARY_PATH");
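The generated_name_override callback added above is what turns llama_* symbols into bare names inside the generated module: returning Some(new_name) overrides bindgen's default name, returning None keeps it. A minimal, self-contained sketch of that renaming rule (illustrative only, not part of the commit):

    // Standalone demo of the rename rule PrefixStripper applies (assumed equivalent).
    fn strip(name: &str) -> Option<String> {
        name.strip_prefix("llama_").map(str::to_string)
    }

    fn main() {
        assert_eq!(strip("llama_decode"), Some("decode".to_string()));         // -> llamacpp::decode
        assert_eq!(strip("llama_model_free"), Some("model_free".to_string())); // -> llamacpp::model_free
        assert_eq!(strip("ggml_type"), None); // no prefix: default name kept  // -> llamacpp::ggml_type
    }

Combined with the module rename below, call sites shrink from bindings::llama_decode(...) to llamacpp::decode(...).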
@@ -28,13 +39,14 @@ fn main() {
     let bindings = bindgen::Builder::default()
         .header("src/wrapper.h")
         .prepend_enum_name(false)
+        .parse_callbacks(Box::new(PrefixStripper))
         .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
         .generate()
         .expect("Unable to generate bindings");

     let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
     bindings
-        .write_to_file(out_path.join("bindings.rs"))
+        .write_to_file(out_path.join("llamacpp.rs"))
         .expect("Couldn't write bindings!");

     if let Some(pkg_cuda) = pkg_cuda {
@@ -1,9 +1,9 @@
-mod bindings {
+mod llamacpp {
     #![allow(non_upper_case_globals)]
     #![allow(non_camel_case_types)]
     #![allow(non_snake_case)]
     #![allow(dead_code)]
-    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
+    include!(concat!(env!("OUT_DIR"), "/llamacpp.rs"));
 }
 use async_trait::async_trait;
 use std::ffi::CString;
@@ -91,39 +91,39 @@ pub enum LlamacppGGMLType {

 // TODO: macro
 impl LlamacppGGMLType {
-    fn to_ggml_type(&self) -> bindings::ggml_type {
+    fn to_ggml_type(&self) -> llamacpp::ggml_type {
         match self {
-            LlamacppGGMLType::F32 => bindings::GGML_TYPE_F32,
-            LlamacppGGMLType::F16 => bindings::GGML_TYPE_F16,
-            LlamacppGGMLType::Q4_0 => bindings::GGML_TYPE_Q4_0,
-            LlamacppGGMLType::Q4_1 => bindings::GGML_TYPE_Q4_1,
-            LlamacppGGMLType::Q5_0 => bindings::GGML_TYPE_Q5_0,
-            LlamacppGGMLType::Q5_1 => bindings::GGML_TYPE_Q5_1,
-            LlamacppGGMLType::Q8_0 => bindings::GGML_TYPE_Q8_0,
-            LlamacppGGMLType::Q8_1 => bindings::GGML_TYPE_Q8_1,
-            LlamacppGGMLType::Q2_K => bindings::GGML_TYPE_Q2_K,
-            LlamacppGGMLType::Q3_K => bindings::GGML_TYPE_Q3_K,
-            LlamacppGGMLType::Q4_K => bindings::GGML_TYPE_Q4_K,
-            LlamacppGGMLType::Q5_K => bindings::GGML_TYPE_Q5_K,
-            LlamacppGGMLType::Q6_K => bindings::GGML_TYPE_Q6_K,
-            LlamacppGGMLType::Q8_K => bindings::GGML_TYPE_Q8_K,
-            LlamacppGGMLType::IQ2_XXS => bindings::GGML_TYPE_IQ2_XXS,
-            LlamacppGGMLType::IQ2_XS => bindings::GGML_TYPE_IQ2_XS,
-            LlamacppGGMLType::IQ3_XXS => bindings::GGML_TYPE_IQ3_XXS,
-            LlamacppGGMLType::IQ1_S => bindings::GGML_TYPE_IQ1_S,
-            LlamacppGGMLType::IQ4_NL => bindings::GGML_TYPE_IQ4_NL,
-            LlamacppGGMLType::IQ3_S => bindings::GGML_TYPE_IQ3_S,
-            LlamacppGGMLType::IQ2_S => bindings::GGML_TYPE_IQ2_S,
-            LlamacppGGMLType::IQ4_XS => bindings::GGML_TYPE_IQ4_XS,
-            LlamacppGGMLType::I8 => bindings::GGML_TYPE_I8,
-            LlamacppGGMLType::I16 => bindings::GGML_TYPE_I16,
-            LlamacppGGMLType::I32 => bindings::GGML_TYPE_I32,
-            LlamacppGGMLType::I64 => bindings::GGML_TYPE_I64,
-            LlamacppGGMLType::F64 => bindings::GGML_TYPE_F64,
-            LlamacppGGMLType::IQ1_M => bindings::GGML_TYPE_IQ1_M,
-            LlamacppGGMLType::BF16 => bindings::GGML_TYPE_BF16,
-            LlamacppGGMLType::TQ1_0 => bindings::GGML_TYPE_TQ1_0,
-            LlamacppGGMLType::TQ2_0 => bindings::GGML_TYPE_TQ2_0,
+            LlamacppGGMLType::F32 => llamacpp::GGML_TYPE_F32,
+            LlamacppGGMLType::F16 => llamacpp::GGML_TYPE_F16,
+            LlamacppGGMLType::Q4_0 => llamacpp::GGML_TYPE_Q4_0,
+            LlamacppGGMLType::Q4_1 => llamacpp::GGML_TYPE_Q4_1,
+            LlamacppGGMLType::Q5_0 => llamacpp::GGML_TYPE_Q5_0,
+            LlamacppGGMLType::Q5_1 => llamacpp::GGML_TYPE_Q5_1,
+            LlamacppGGMLType::Q8_0 => llamacpp::GGML_TYPE_Q8_0,
+            LlamacppGGMLType::Q8_1 => llamacpp::GGML_TYPE_Q8_1,
+            LlamacppGGMLType::Q2_K => llamacpp::GGML_TYPE_Q2_K,
+            LlamacppGGMLType::Q3_K => llamacpp::GGML_TYPE_Q3_K,
+            LlamacppGGMLType::Q4_K => llamacpp::GGML_TYPE_Q4_K,
+            LlamacppGGMLType::Q5_K => llamacpp::GGML_TYPE_Q5_K,
+            LlamacppGGMLType::Q6_K => llamacpp::GGML_TYPE_Q6_K,
+            LlamacppGGMLType::Q8_K => llamacpp::GGML_TYPE_Q8_K,
+            LlamacppGGMLType::IQ2_XXS => llamacpp::GGML_TYPE_IQ2_XXS,
+            LlamacppGGMLType::IQ2_XS => llamacpp::GGML_TYPE_IQ2_XS,
+            LlamacppGGMLType::IQ3_XXS => llamacpp::GGML_TYPE_IQ3_XXS,
+            LlamacppGGMLType::IQ1_S => llamacpp::GGML_TYPE_IQ1_S,
+            LlamacppGGMLType::IQ4_NL => llamacpp::GGML_TYPE_IQ4_NL,
+            LlamacppGGMLType::IQ3_S => llamacpp::GGML_TYPE_IQ3_S,
+            LlamacppGGMLType::IQ2_S => llamacpp::GGML_TYPE_IQ2_S,
+            LlamacppGGMLType::IQ4_XS => llamacpp::GGML_TYPE_IQ4_XS,
+            LlamacppGGMLType::I8 => llamacpp::GGML_TYPE_I8,
+            LlamacppGGMLType::I16 => llamacpp::GGML_TYPE_I16,
+            LlamacppGGMLType::I32 => llamacpp::GGML_TYPE_I32,
+            LlamacppGGMLType::I64 => llamacpp::GGML_TYPE_I64,
+            LlamacppGGMLType::F64 => llamacpp::GGML_TYPE_F64,
+            LlamacppGGMLType::IQ1_M => llamacpp::GGML_TYPE_IQ1_M,
+            LlamacppGGMLType::BF16 => llamacpp::GGML_TYPE_BF16,
+            LlamacppGGMLType::TQ1_0 => llamacpp::GGML_TYPE_TQ1_0,
+            LlamacppGGMLType::TQ2_0 => llamacpp::GGML_TYPE_TQ2_0,
         }
     }
 }
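The "// TODO: macro" comment in this hunk refers to replacing the hand-written match with a generated one. One possible shape for such a macro, sketched here purely as an assumption (it is not in the codebase, and it pulls in the paste crate for identifier concatenation); the llamacpp module below is a stand-in with made-up constant values so the sketch compiles on its own:

    mod llamacpp {
        // Stand-in for the generated bindings; real values come from ggml.
        pub type ggml_type = u32;
        pub const GGML_TYPE_F32: ggml_type = 0;
        pub const GGML_TYPE_F16: ggml_type = 1;
        pub const GGML_TYPE_Q8_0: ggml_type = 8;
    }

    // Generate the enum and the conversion from a single list of variants.
    macro_rules! ggml_type_map {
        ($($variant:ident),* $(,)?) => {
            #[derive(Debug, Clone, Copy)]
            pub enum LlamacppGGMLType { $($variant),* }

            impl LlamacppGGMLType {
                fn to_ggml_type(&self) -> llamacpp::ggml_type {
                    match self {
                        $(LlamacppGGMLType::$variant =>
                            ::paste::paste!(llamacpp::[<GGML_TYPE_ $variant>])),*
                    }
                }
            }
        };
    }

    ggml_type_map!(F32, F16, Q8_0);

    fn main() {
        assert_eq!(LlamacppGGMLType::Q8_0.to_ggml_type(), 8);
    }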
@@ -201,16 +201,16 @@ impl LlamacppRequest {
 }

 struct Llamacpp {
-    model: *mut bindings::llama_model,
-    ctx: *mut bindings::llama_context,
-    vocab: *const bindings::llama_vocab,
-    logprobs: Vec<bindings::llama_token_data>,
-    batch: bindings::llama_batch,
+    model: *mut llamacpp::llama_model,
+    ctx: *mut llamacpp::llama_context,
+    vocab: *const llamacpp::llama_vocab,
+    logprobs: Vec<llamacpp::llama_token_data>,
+    batch: llamacpp::llama_batch,
     n_ctx: u32,
 }

 extern "C" fn llamacpp_log_callback(
-    level: bindings::ggml_log_level,
+    level: llamacpp::ggml_log_level,
     msg: *const std::os::raw::c_char,
     _user_data: *mut std::os::raw::c_void,
 ) {
@@ -218,10 +218,10 @@ extern "C" fn llamacpp_log_callback(
     let rmsg = cmsg.to_string_lossy().trim_end_matches('\n').to_string();

     match level {
-        bindings::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg),
-        bindings::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg),
-        bindings::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg),
-        bindings::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg),
+        llamacpp::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg),
+        llamacpp::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg),
+        llamacpp::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg),
+        llamacpp::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg),
         _ => trace!(target: "llamacpp", "{}", rmsg),
     }
 }
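The callback above receives a raw C string from llama.cpp; the elided context line presumably builds cmsg from the msg pointer with CStr::from_ptr before the to_string_lossy line shown. A self-contained sketch of that conversion, under that assumption:

    use std::ffi::CStr;
    use std::os::raw::c_char;

    // Turn the NUL-terminated C string passed to the log callback into an owned
    // Rust String, dropping a trailing newline as the code above does.
    unsafe fn log_msg(msg: *const c_char) -> String {
        CStr::from_ptr(msg)
            .to_string_lossy()
            .trim_end_matches('\n')
            .to_string()
    }

    fn main() {
        let raw = b"loaded model\n\0";
        let s = unsafe { log_msg(raw.as_ptr() as *const c_char) };
        assert_eq!(s, "loaded model");
    }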
@@ -231,12 +231,12 @@ impl Llamacpp {
         let gguf = CString::new(conf.model_gguf)?;

         let model = unsafe {
-            let mut params = bindings::llama_model_default_params();
+            let mut params = llamacpp::model_default_params();
             params.n_gpu_layers = conf.n_gpu_layers as _;
             params.split_mode = match conf.split_mode {
-                LlamacppSplitMode::GPU(_) => bindings::LLAMA_SPLIT_MODE_NONE,
-                LlamacppSplitMode::Layer => bindings::LLAMA_SPLIT_MODE_LAYER,
-                LlamacppSplitMode::Row => bindings::LLAMA_SPLIT_MODE_ROW,
+                LlamacppSplitMode::GPU(_) => llamacpp::LLAMA_SPLIT_MODE_NONE,
+                LlamacppSplitMode::Layer => llamacpp::LLAMA_SPLIT_MODE_LAYER,
+                LlamacppSplitMode::Row => llamacpp::LLAMA_SPLIT_MODE_ROW,
             };
             params.main_gpu = match conf.split_mode {
                 LlamacppSplitMode::GPU(n) => n as _,
@@ -244,13 +244,13 @@ impl Llamacpp {
             };
             params.use_mmap = conf.use_mmap;
             params.use_mlock = conf.use_mlock;
-            bindings::llama_model_load_from_file(gguf.as_ptr(), params)
+            llamacpp::model_load_from_file(gguf.as_ptr(), params)
         };
         if model.is_null() {
             return Err(BackendError::Llamacpp("Failed to load model".to_string()))
         }
         let ctx = unsafe {
-            let mut params = bindings::llama_context_default_params();
+            let mut params = llamacpp::context_default_params();
             params.n_ctx = conf.n_ctx as _;
             params.n_batch = conf.max_batch_total_tokens as _;
             params.n_ubatch = conf.max_physical_batch_total_tokens as _;
@@ -263,48 +263,48 @@ impl Llamacpp {
             params.type_k = conf.type_k.to_ggml_type();
             params.type_v = conf.type_v.to_ggml_type();
             params.no_perf = true;
-            bindings::llama_init_from_model(model, params)
+            llamacpp::init_from_model(model, params)
         };
         if ctx.is_null() {
             return Err(BackendError::Llamacpp("Failed to init context".to_string()))
         }
-        let n_ctx = unsafe { bindings::llama_n_ctx(ctx) };
+        let n_ctx = unsafe { llamacpp::n_ctx(ctx) };

         let vocab = unsafe {
-            bindings::llama_model_get_vocab(model)
+            llamacpp::model_get_vocab(model)
         };
         if vocab.is_null() {
             return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
         }
         let n_tokens = unsafe {
-            bindings::llama_vocab_n_tokens(vocab)
+            llamacpp::vocab_n_tokens(vocab)
         };
         let mut logprobs = Vec::with_capacity(n_tokens as usize);

         for token in 0..n_tokens {
-            logprobs.push(bindings::llama_token_data {
+            logprobs.push(llamacpp::llama_token_data {
                 id: token,
                 logit: 0.0,
                 p: 0.0,
             });
         }
         let batch = unsafe {
-            bindings::llama_batch_init(conf.max_batch_total_tokens as _, 0, 1)
+            llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1)
         };
         Ok(Llamacpp{model, ctx, vocab, logprobs, n_ctx, batch})
     }

-    fn clear_kv_cache(&mut self, seq_id: bindings::llama_seq_id) {
+    fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {
         unsafe {
-            bindings::llama_kv_cache_seq_rm(self.ctx, seq_id, -1, -1);
+            llamacpp::kv_cache_seq_rm(self.ctx, seq_id, -1, -1);
         }
     }

     fn batch_push(
         &mut self,
-        token: bindings::llama_token,
-        pos: bindings::llama_pos,
-        seq_id: bindings::llama_seq_id,
+        token: llamacpp::llama_token,
+        pos: llamacpp::llama_pos,
+        seq_id: llamacpp::llama_seq_id,
         logits: bool,
     ) -> usize {
         let n = self.batch.n_tokens as usize;
@@ -323,43 +323,43 @@ impl Llamacpp {
 impl Drop for Llamacpp {
     fn drop(&mut self) {
         if !self.ctx.is_null() {
-            unsafe { bindings::llama_free(self.ctx) };
+            unsafe { llamacpp::free(self.ctx) };
         }
         if !self.model.is_null() {
-            unsafe { bindings::llama_model_free(self.model) };
+            unsafe { llamacpp::model_free(self.model) };
         }
-        unsafe { bindings::llama_batch_free(self.batch) };
+        unsafe { llamacpp::batch_free(self.batch) };
     }
 }

 struct LlamacppSampler {
-    chain: *mut bindings::llama_sampler,
+    chain: *mut llamacpp::llama_sampler,
 }

 impl LlamacppSampler {
     fn new(req: &LlamacppRequest) -> Option<Self> {
         let chain = unsafe {
-            let params = bindings::llama_sampler_chain_default_params();
-            bindings::llama_sampler_chain_init(params)
+            let params = llamacpp::sampler_chain_default_params();
+            llamacpp::sampler_chain_init(params)
         };
         if chain.is_null() {
             error!("Failed to init sampler");
             return None;
         }
         let top_k = unsafe {
-            bindings::llama_sampler_init_top_k(req.top_k)
+            llamacpp::sampler_init_top_k(req.top_k)
         };
         let top_p = unsafe {
-            bindings::llama_sampler_init_top_p(req.top_p, req.min_keep)
+            llamacpp::sampler_init_top_p(req.top_p, req.min_keep)
         };
         let typical_p = unsafe {
-            bindings::llama_sampler_init_typical(req.typical_p, req.min_keep)
+            llamacpp::sampler_init_typical(req.typical_p, req.min_keep)
         };
         let temp = unsafe {
-            bindings::llama_sampler_init_temp(req.temp)
+            llamacpp::sampler_init_temp(req.temp)
         };
         let penalties = unsafe {
-            bindings::llama_sampler_init_penalties(
+            llamacpp::sampler_init_penalties(
                 req.penalty_last_n,
                 req.penalty_repeat,
                 req.penalty_freq,
@@ -367,7 +367,7 @@ impl LlamacppSampler {
             )
         };
         let dist = unsafe {
-            bindings::llama_sampler_init_dist(req.seed)
+            llamacpp::sampler_init_dist(req.seed)
         };
         let mut failed = false;

@@ -381,7 +381,7 @@ impl LlamacppSampler {
                 error!("Failed to init {k} sampler");
                 failed = true;
             } else {
-                unsafe { bindings::llama_sampler_chain_add(chain, *v) };
+                unsafe { llamacpp::sampler_chain_add(chain, *v) };
             }
         }
         if failed {
@@ -391,27 +391,27 @@ impl LlamacppSampler {
         }
     }

-    fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (bindings::llama_token, f32) {
+    fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (llamacpp::llama_token, f32) {
         let logits = unsafe {
-            bindings::llama_get_logits_ith(llamacpp.ctx, idx as _)
+            llamacpp::get_logits_ith(llamacpp.ctx, idx as _)
         };
         for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
-            *logprob = bindings::llama_token_data {
+            *logprob = llamacpp::llama_token_data {
                 id: token as _,
                 logit: unsafe { *logits.offset(token as _) },
                 p: 0.0,
             };
         }
-        let mut view = bindings::llama_token_data_array {
+        let mut view = llamacpp::llama_token_data_array {
             data: llamacpp.logprobs.as_mut_ptr(),
             size: llamacpp.logprobs.len(),
             selected: -1,
             sorted: false,
         };
         unsafe {
-            bindings::llama_sampler_apply(self.chain, &mut view);
+            llamacpp::sampler_apply(self.chain, &mut view);
             let logprob = *view.data.offset(view.selected as _);
-            bindings::llama_sampler_accept(self.chain, logprob.id);
+            llamacpp::sampler_accept(self.chain, logprob.id);
             (logprob.id, logprob.p.ln())
         }
     }
@@ -420,7 +420,7 @@ impl LlamacppSampler {
 impl Drop for LlamacppSampler {
     fn drop(&mut self) {
         if !self.chain.is_null() {
-            unsafe { bindings::llama_sampler_free(self.chain) };
+            unsafe { llamacpp::sampler_free(self.chain) };
         }
     }
 }
@@ -428,8 +428,8 @@ impl Drop for LlamacppSampler {
 struct LlamacppSeq {
     id: usize,
     batch_pos: usize,
-    token: bindings::llama_token,
-    pos: bindings::llama_pos,
+    token: llamacpp::llama_token,
+    pos: llamacpp::llama_pos,
     sampler: LlamacppSampler,
     text: String,
     n_new_tokens: usize,
@@ -446,14 +446,14 @@ impl LlamacppBackend {

         // Setup llama & export logs, once and for all
         INIT.call_once(|| unsafe {
-            bindings::llama_log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
-            bindings::llama_backend_init();
-            bindings::llama_numa_init(match conf.numa {
-                LlamacppNuma::Disabled => bindings::GGML_NUMA_STRATEGY_DISABLED,
-                LlamacppNuma::Distribute => bindings::GGML_NUMA_STRATEGY_DISTRIBUTE,
-                LlamacppNuma::Isolate => bindings::GGML_NUMA_STRATEGY_ISOLATE,
-                LlamacppNuma::Numactl => bindings::GGML_NUMA_STRATEGY_NUMACTL,
-                LlamacppNuma::Mirror => bindings::GGML_NUMA_STRATEGY_MIRROR,
+            llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
+            llamacpp::backend_init();
+            llamacpp::numa_init(match conf.numa {
+                LlamacppNuma::Disabled => llamacpp::GGML_NUMA_STRATEGY_DISABLED,
+                LlamacppNuma::Distribute => llamacpp::GGML_NUMA_STRATEGY_DISTRIBUTE,
+                LlamacppNuma::Isolate => llamacpp::GGML_NUMA_STRATEGY_ISOLATE,
+                LlamacppNuma::Numactl => llamacpp::GGML_NUMA_STRATEGY_NUMACTL,
+                LlamacppNuma::Mirror => llamacpp::GGML_NUMA_STRATEGY_MIRROR,
             });
         });

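The INIT.call_once block in this hunk relies on std::sync::Once so that the process-wide llama.cpp setup (log_set, backend_init, numa_init) runs exactly once even if several backends are constructed. A minimal sketch of that pattern in isolation (the function name and body are illustrative, not from the diff):

    use std::sync::Once;

    static INIT: Once = Once::new();

    fn init_backend_once() {
        INIT.call_once(|| {
            // Process-wide, one-time setup goes here; later calls are no-ops,
            // and concurrent callers block until the first call finishes.
            println!("backend initialized");
        });
    }

    fn main() {
        init_backend_once();
        init_backend_once(); // second call does nothing
    }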
@@ -526,17 +526,17 @@ impl LlamacppBackend {

             for (pos, &token_id) in request.input_ids.iter().enumerate() {
                 llamacpp.batch_push(
-                    token_id as bindings::llama_token,
-                    pos as bindings::llama_pos,
-                    seq_id as bindings::llama_seq_id,
+                    token_id as llamacpp::llama_token,
+                    pos as llamacpp::llama_pos,
+                    seq_id as llamacpp::llama_seq_id,
                     pos == last_pos, // check samplers
                 );
             }
             seqs.push(LlamacppSeq {
                 id: seq_id,
                 batch_pos: llamacpp.batch.n_tokens as usize - 1,
-                token: bindings::LLAMA_TOKEN_NULL,
-                pos: last_pos as bindings::llama_pos + 1,
+                token: llamacpp::LLAMA_TOKEN_NULL,
+                pos: last_pos as llamacpp::llama_pos + 1,
                 sampler: sampler,
                 text: String::with_capacity(1024),
                 n_new_tokens: 0,
@@ -548,7 +548,7 @@ impl LlamacppBackend {
                 break;
             }
             let decode = unsafe {
-                bindings::llama_decode(llamacpp.ctx, llamacpp.batch)
+                llamacpp::decode(llamacpp.ctx, llamacpp.batch)
             };
             if decode != 0 {
                 warn!("llama_decode failed, clearing kv cache");
@@ -560,7 +560,7 @@ impl LlamacppBackend {
                 break;
             }
             let kv_cache_used_cells = unsafe {
-                bindings::llama_get_kv_cache_used_cells(llamacpp.ctx)
+                llamacpp::get_kv_cache_used_cells(llamacpp.ctx)
             };
             for seq in seqs.iter_mut() {
                 if !seq.running {
@@ -591,7 +591,7 @@ impl LlamacppBackend {
                     special: special,
                 };
                 let finish: Option<FinishReason> = {
-                    if unsafe { bindings::llama_vocab_is_eog(llamacpp.vocab, next) } {
+                    if unsafe { llamacpp::vocab_is_eog(llamacpp.vocab, next) } {
                         Some(FinishReason::EndOfSequenceToken)
                     } else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
                         Some(FinishReason::Length)