text-generation-inference/backends/llamacpp/src/lib.rs

use crate::backend::InferContext;
use crate::ffi::SamplingParams;
pub mod backend;

impl Default for SamplingParams {
    fn default() -> Self {
        Self {
            // top_k = u32::MAX and top_p = 1.0 effectively leave the token distribution untruncated.
            top_k: u32::MAX,
            top_p: 1.0f32,
            frequency_penalty: 0.0f32,
            repetition_penalty: 0.0f32,
            // Fixed default seed.
            seed: 2014u64,
        }
    }
}

#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
mod ffi {
    #[derive(Debug, Copy, Clone)]
    struct GenerationParams {
        max_new_tokens: u32,
        ignore_eos_token: bool,
    }

    #[derive(Debug, Copy, Clone)]
    struct SamplingParams {
        top_k: u32,
        top_p: f32,
        frequency_penalty: f32,
        repetition_penalty: f32,
        seed: u64,
    }

    extern "Rust" {
        type InferContext<'a>;
    }

    unsafe extern "C++" {
        include!("backends/llamacpp/csrc/ffi.hpp");

        #[cxx_name = "generation_params_t"]
        type GenerationParams;

        #[cxx_name = "sampling_params_t"]
        type SamplingParams;

        /// Represents an instance of the llama.cpp backend on the C++ side.
        #[cxx_name = "llama_cpp_worker_frontend_t"]
        type LlamaCppWorkerFrontend;

        /// Creates a llama.cpp worker frontend for the model stored at `modelPath`.
        fn create_worker_frontend(modelPath: &str) -> Result<UniquePtr<LlamaCppWorkerFrontend>>;

        /// Runs generation for the prompt `tokens`, surfacing produced tokens through `callback`.
        unsafe fn stream(
            self: Pin<&mut LlamaCppWorkerFrontend>,
            tokens: &[u32],
            generation_params: GenerationParams,
            sampling_params: &SamplingParams,
            stream: *mut InferContext,
            callback: unsafe fn(*mut InferContext, u32, f32, bool, usize) -> bool,
        ) -> Result<usize>;
    }
}
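
// A minimal sketch, not part of the original file: a unit test exercising the
// `Default` implementation above. It assumes the crate's C++ sources build and link
// so that `cargo test` can compile the cxx bridge.
#[cfg(test)]
mod tests {
    use super::ffi::SamplingParams;

    #[test]
    fn default_sampling_params_match_declared_values() {
        let params = SamplingParams::default();
        assert_eq!(params.top_k, u32::MAX);
        assert_eq!(params.top_p, 1.0f32);
        assert_eq!(params.frequency_penalty, 0.0f32);
        assert_eq!(params.repetition_penalty, 0.0f32);
        assert_eq!(params.seed, 2014u64);
    }
}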