text-generation-inference/backends/llamacpp/src/lib.rs


use crate::ffi::SamplingParams;
pub mod backend;

impl Default for SamplingParams {
    fn default() -> Self {
        Self {
            top_k: u32::MAX, // u32::MAX effectively disables top-k filtering
            top_p: 1.0f32,   // 1.0 disables nucleus (top-p) filtering
            frequency_penalty: 0.0f32,
            repetition_penalty: 0.0f32,
            seed: 2014u64,
        }
    }
}
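
// Hedged usage note: callers are expected to start from these defaults and
// override individual fields via struct-update syntax, e.g. (values
// illustrative, not taken from this file):
//
//     let params = SamplingParams { top_k: 40, top_p: 0.95, ..Default::default() };
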
#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
mod ffi {
    struct GenerationParams {
        max_new_tokens: u32,
        ignore_eos_token: bool,
    }

    struct SamplingParams {
        top_k: u32,
        top_p: f32,
        frequency_penalty: f32,
        repetition_penalty: f32,
        seed: u64,
    }

    unsafe extern "C++" {
        include!("backends/llamacpp/csrc/ffi.hpp");
#[cxx_name = "generation_params_t"]
type GenerationParams;
#[cxx_name = "sampling_params_t"]
type SamplingParams;

        /// Represents an instance of the llama.cpp backend on the C++ side
        #[cxx_name = "llama_cpp_backend_impl_t"]
        type LlamaCppBackendImpl;

        #[rust_name = "create_single_worker_backend"]
        fn create_single_worker_backend(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;

        fn generate(
            self: Pin<&mut LlamaCppBackendImpl>,
            tokens: &[u32],
            generated: &mut [u32],
            generation_params: &GenerationParams,
            sampling_params: &SamplingParams,
            callback: fn(u32, f32, bool),
        ) -> Result<usize>;
    }
}
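
// A minimal usage sketch, not part of the original file: it assumes a local
// GGUF model path, an output buffer sized to `max_new_tokens`, and that the
// callback arguments are (token_id, score, is_final); the actual contract is
// defined by the C++ side in backends/llamacpp/csrc/ffi.hpp.
#[cfg(test)]
mod usage_sketch {
    use super::ffi;

    #[test]
    #[ignore = "requires a local GGUF model file"]
    fn single_worker_generate() -> Result<(), Box<dyn std::error::Error>> {
        // Hypothetical model path, for illustration only.
        let mut backend = ffi::create_single_worker_backend("/models/model.gguf")?;

        let prompt: Vec<u32> = vec![1, 15043, 29892]; // illustrative token ids
        let mut generated = vec![0u32; 128]; // room for up to max_new_tokens outputs

        let generation_params = ffi::GenerationParams {
            max_new_tokens: 128,
            ignore_eos_token: false,
        };
        let sampling_params = ffi::SamplingParams::default();

        // Non-capturing closures coerce to the `fn(u32, f32, bool)` pointer the
        // bridge expects; presumed to fire once per generated token.
        let n_generated = backend.pin_mut().generate(
            &prompt,
            &mut generated,
            &generation_params,
            &sampling_params,
            |token_id, score, is_final| {
                let _ = (token_id, score, is_final);
            },
        )?;

        assert!(n_generated <= generated.len());
        Ok(())
    }
}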