diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp
index daf8de54..f2f5d4c6 100644
--- a/backends/llamacpp/csrc/backend.cpp
+++ b/backends/llamacpp/csrc/backend.cpp
@@ -21,7 +21,7 @@ namespace huggingface::tgi::backends::llamacpp {
             batch.token[i] = input_tokens[i];
             batch.pos[i] = i;
             batch.n_seq_id[i] = 1;
-            batch.seq_id[i] = 0;
+            batch.seq_id[i] = nullptr;
             batch.logits[i] = false;
             ++batch.n_tokens;
         }
@@ -84,13 +84,12 @@ namespace huggingface::tgi::backends::llamacpp {
             const generation_context_t &generation_context,
             const std::optional<llama_decode_callback> &callback) const {
         // Store information about context and generation size
-        auto prompt_length = std::ssize(generation_context.input_tokens);
         auto max_new_tokens = generation_context.generation_params.max_new_tokens;
 
         // Convert sampling params to what llama.cpp is looking for
         auto sampler = generation_context.sampling_params.into_llama_sampler(mModel_.get());
 
-        // Setup the prompt
+        // Set up the prompt
         auto copy = std::vector(generation_context.input_tokens.begin(), generation_context.input_tokens.end());
         auto batch = llama_batch_get_one(copy.data(), copy.size());
 
@@ -168,4 +167,15 @@ namespace huggingface::tgi::backends::llamacpp {
     ) {
         return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens, out}, callback);
     }
+
+    std::expected<size_t, backend_error_t>
+    multi_worker_backend_t::generate(
+            std::span<const llama_token>,
+            std::span<llama_token>,
+            const generation_params_t &generation_params,
+            const sampling_params_t &sampling_params,
+            const std::optional<llama_decode_callback> &callback) {
+        SPDLOG_ERROR("Not implemented yet");
+        return 0uz;
+    }
 }
\ No newline at end of file
diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp
index e7545a3c..871490f2 100644
--- a/backends/llamacpp/csrc/backend.hpp
+++ b/backends/llamacpp/csrc/backend.hpp
@@ -180,8 +180,20 @@ namespace huggingface::tgi::backends::llamacpp {
                 const sampling_params_t &sampling_params,
                 const std::optional<llama_decode_callback> &callback
         ) override;
+    };
 
+    class multi_worker_backend_t : backend_base_t {
+    private:
+        llama_context_smart_ptr mContext_;
+
+    public:
+        std::expected<size_t, backend_error_t> generate(
+                std::span<const llama_token>,
+                std::span<llama_token>,
+                const generation_params_t &generation_params,
+                const sampling_params_t &sampling_params,
+                const std::optional<llama_decode_callback> &callback
+        ) override;
     };
 }
diff --git a/backends/llamacpp/csrc/ffi.hpp b/backends/llamacpp/csrc/ffi.hpp
index d15728b9..18254114 100644
--- a/backends/llamacpp/csrc/ffi.hpp
+++ b/backends/llamacpp/csrc/ffi.hpp
@@ -12,36 +12,92 @@
 #include
 #include "backend.hpp"
 
-namespace huggingface::tgi::backends::llamacpp::impl {
-    class LlamaCppBackendImpl;
+namespace huggingface::tgi::backends::llamacpp {
+    struct generation_params_t;
+    struct sampling_params_t;
+
+    class llama_cpp_backend_impl_t;
 }
 
 #include "backends/llamacpp/src/lib.rs.h"
 
-namespace huggingface::tgi::backends::llamacpp::impl {
+namespace huggingface::tgi::backends::llamacpp {
 
-    class LlamaCppBackendException : std::exception {
+    // Concept identifying types which have a .generate() -> size_t method to do in-place generation
+    template<typename T>
+    concept has_emplace_generate = requires(
+            T t,
+            std::span<const llama_token> input_tokens,
+            std::span<llama_token> generated_tokens,
+            const generation_params_t &generation_params,
+            const sampling_params_t &sampling_params,
+            llama_decode_callback callback
+    ) {
+        {
+            t.generate(input_tokens, generated_tokens, generation_params, sampling_params, callback)
+        } -> std::same_as<std::expected<size_t, backend_error_t>>;
+    };
+
+    static_assert(has_emplace_generate<single_worker_backend_t>,
+                  "single_worker_backend_t doesn't meet concept is_generate_emplace_capable");
+    static_assert(has_emplace_generate<multi_worker_backend_t>,
+                  "multi_worker_backend_t doesn't meet concept is_generate_emplace_capable");
+
+    class llama_cpp_backend_exception_t : std::exception {
 
     };
 
-    class LlamaCppBackendImpl {
+    /**
+     * Llama.cpp backend interfacing with Rust FFI layer
+     */
+    class llama_cpp_backend_impl_t {
     private:
-        BackendBase _inner;
+        std::variant<single_worker_backend_t, multi_worker_backend_t> mInner_;
 
     public:
-        LlamaCppBackendImpl(llama_model *model) : _inner(model) {}
+        explicit llama_cpp_backend_impl_t(single_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
+
+        explicit llama_cpp_backend_impl_t(multi_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
+
+        size_t generate(
+                rust::Slice<const uint32_t> input_tokens,
+                rust::Slice<uint32_t> generated_tokens,
+                const generation_params_t &generation_params,
+                const sampling_params_t &sampling_params,
+                rust::Fn<void(uint32_t, bool)> callback
+        ) {
+            // Define the visitor lambda function which requires the has_emplace_generate constraint on T
+            static auto inner_fw = [=, &generation_params, &sampling_params]<has_emplace_generate T>(T &&backend)
+                    -> std::expected<size_t, backend_error_t> {
+
+                // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
+                auto input_tokens_v =
+                        std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
+                auto generated_tokens_v =
+                        std::span(reinterpret_cast<llama_token *>(generated_tokens.data()), generated_tokens.size());
+
+                return backend.generate(
+                        input_tokens_v, generated_tokens_v, generation_params, sampling_params, callback);
+            };
+
+            if (const auto result = std::visit(inner_fw, mInner_); result.has_value()) {
+                return *result;
+            } else {
+                throw llama_cpp_backend_exception_t();
+            }
+        }
     };
 
-    std::unique_ptr<LlamaCppBackendImpl> CreateLlamaCppBackendImpl(rust::Str modelPath, uint16_t nThreads) {
-        const auto cxxPath = std::string_view(modelPath);
-        if (auto maybe = TgiLlamaCppBackend::FromGGUF(std::filesystem::path(cxxPath), nThreads); maybe.has_value()) {
-            auto [model, context] = *maybe;
-            return std::make_unique<LlamaCppBackendImpl>(model, context);
-        } else {
-            throw LlamaCppBackendException();
-        }
+    std::unique_ptr<llama_cpp_backend_impl_t> create_single_worker_backend(rust::Str modelPath) {
+        const auto cxxPath = std::string(modelPath);
+        auto params = llama_model_default_params();
+        params.use_mmap = true;
+
+        auto *model = llama_load_model_from_file(cxxPath.c_str(), params);
+        auto backend = single_worker_backend_t(model, std::nullopt);
+        return std::make_unique<llama_cpp_backend_impl_t>(std::move(backend));
     }
 }
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index af50470d..6e9e8d2d 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -1,7 +1,8 @@
-use crate::ffi::{create_llamacpp_backend, LlamaCppBackendImpl};
+use crate::ffi::{
+    create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
+};
 use async_trait::async_trait;
 use cxx::{Exception, UniquePtr};
-use std::ops::Deref;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::thread::spawn;
@@ -25,10 +26,7 @@ pub enum LlamaCppBackendError {
 pub struct LlamaCppBackend {}
 
 impl LlamaCppBackend {
-    pub fn new<P: AsRef<Path> + Send>(
-        model_path: P,
-        n_threads: u16,
-    ) -> Result<Self, LlamaCppBackendError> {
+    pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
         let path = Arc::new(model_path.as_ref());
         if !path.exists() {
             return Err(LlamaCppBackendError::ModelFileDoesntExist(
@@ -36,13 +34,12 @@ impl LlamaCppBackend {
             ));
         }
 
-        let mut backend =
-            create_llamacpp_backend(path.to_str().unwrap(), n_threads).map_err(|err| {
-                LlamaCppBackendError::ModelInitializationFailed(
-                    path.to_path_buf(),
-                    err.what().to_string(),
-                )
-            })?;
+        let mut backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
+            LlamaCppBackendError::ModelInitializationFailed(
+                path.to_path_buf(),
+                err.what().to_string(),
+            )
+        })?;
 
         info!(
             "Successfully initialized llama.cpp backend from {}",
@@ -57,12 +54,20 @@ impl LlamaCppBackend {
 
 fn scheduler_loop(mut backend: UniquePtr<LlamaCppBackendImpl>) {
     println!("Scheduler loop");
-    let tokens = [128000i32, 5159, 836, 374, 23809];
-    let mut generated = vec![0i32; 128];
-    match backend
-        .pin_mut()
-        .generate(&tokens, &mut generated, 40, 32, 1.0, 1.0, 1.0, 1.0, 2014)
-    {
+    let tokens = [128000u32, 5159, 836, 374, 23809];
+    let mut generated = vec![0u32; 16];
+    let generation_params = GenerationParams {
+        max_new_tokens: generated.len() as u32,
+    };
+    let sampling_params = SamplingParams::default();
+
+    match backend.pin_mut().generate(
+        &tokens,
+        &mut generated,
+        &generation_params,
+        &sampling_params,
+        |new_token_id: u32, is_eos: bool| println!("Generated {new_token_id} (is_eos: {is_eos})"),
+    ) {
         Ok(n_tokens) => {
             generated.truncate(n_tokens);
             println!("Generated {} tokens -> {:?}", n_tokens, generated);
diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs
index 673fe130..9fb79501 100644
--- a/backends/llamacpp/src/lib.rs
+++ b/backends/llamacpp/src/lib.rs
@@ -1,17 +1,56 @@
+use crate::ffi::SamplingParams;
+
 pub mod backend;
 
-#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp::impl")]
+impl Default for SamplingParams {
+    fn default() -> Self {
+        Self {
+            top_k: u32::MAX,
+            top_p: 1.0f32,
+            frequency_penalty: 0.0f32,
+            repetition_penalty: 0.0f32,
+            seed: 2014u64,
+        }
+    }
+}
+
+#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
 mod ffi {
+    struct GenerationParams {
+        max_new_tokens: u32,
+    }
+
+    struct SamplingParams {
+        top_k: u32,
+        top_p: f32,
+        frequency_penalty: f32,
+        repetition_penalty: f32,
+        seed: u64,
+    }
+
     unsafe extern "C++" {
         include!("backends/llamacpp/csrc/ffi.hpp");
 
+        #[cxx_name = "generation_params_t"]
+        type GenerationParams;
+
+        #[cxx_name = "sampling_params_t"]
+        type SamplingParams;
+
         /// Represent an instance of the llama.cpp backend instance on C++ side
+        #[cxx_name = "llama_cpp_backend_impl_t"]
        type LlamaCppBackendImpl;
 
-        #[rust_name = "create_llamacpp_backend"]
-        fn CreateLlamaCppBackendImpl(
-            modelPath: &str,
-            n_threads: u16,
-        ) -> Result<UniquePtr<LlamaCppBackendImpl>>;
+        #[rust_name = "create_single_worker_backend"]
+        fn create_single_worker_backend(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;
+
+        fn generate(
+            self: Pin<&mut LlamaCppBackendImpl>,
+            tokens: &[u32],
+            generated: &mut [u32],
+            generation_params: &GenerationParams,
+            sampling_params: &SamplingParams,
+            callback: fn(u32, bool),
+        ) -> Result<usize>;
     }
 }
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 3920da21..62f81848 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -161,7 +161,7 @@ async fn main() -> Result<(), RouterError> {
         }
     }
 
-    let backend = LlamaCppBackend::new(gguf_path, cores_per_instance)?;
+    let backend = LlamaCppBackend::new(gguf_path)?;
 
     // Run server
     server::run(
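
Note: the ffi.hpp changes above dispatch a single FFI entry point through a C++20 concept (has_emplace_generate) that constrains a templated visitor over a std::variant of worker backends, with std::expected carrying the error channel and being converted into an exception at the FFI boundary. Below is a minimal, self-contained sketch of that same pattern, compilable on its own with -std=c++23. It is not part of the diff; every name in it (token_t, error_t, echo_backend_t, failing_backend_t, backend_facade_t) is invented for illustration and does not exist in the TGI codebase.

// Standalone sketch: concept-constrained std::visit over a variant of backends.
#include <algorithm>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <expected>
#include <iostream>
#include <span>
#include <stdexcept>
#include <utility>
#include <variant>
#include <vector>

using token_t = std::int32_t;
enum class error_t : std::uint8_t { GENERATION_FAILED = 1 };

// Any type exposing generate(input, output) -> std::expected<size_t, error_t>
// satisfies the concept, mirroring has_emplace_generate in ffi.hpp.
template<typename T>
concept can_generate = requires(T t, std::span<const token_t> in, std::span<token_t> out) {
    { t.generate(in, out) } -> std::same_as<std::expected<std::size_t, error_t>>;
};

struct echo_backend_t {
    // Trivial "generation": copy the prompt into the output buffer.
    std::expected<std::size_t, error_t> generate(std::span<const token_t> in, std::span<token_t> out) {
        const auto n = std::min(in.size(), out.size());
        std::copy_n(in.begin(), n, out.begin());
        return n;
    }
};

struct failing_backend_t {
    std::expected<std::size_t, error_t> generate(std::span<const token_t>, std::span<token_t>) {
        return std::unexpected(error_t::GENERATION_FAILED);
    }
};

static_assert(can_generate<echo_backend_t>);
static_assert(can_generate<failing_backend_t>);

class backend_facade_t {
    std::variant<echo_backend_t, failing_backend_t> inner_;

public:
    explicit backend_facade_t(echo_backend_t b) : inner_(std::move(b)) {}
    explicit backend_facade_t(failing_backend_t b) : inner_(std::move(b)) {}

    // Single entry point: std::visit dispatches to whichever alternative is held;
    // the lambda's template parameter is constrained by the concept, and the
    // expected<> error is turned into an exception, as the FFI layer does.
    std::size_t generate(std::span<const token_t> in, std::span<token_t> out) {
        auto fw = [&]<can_generate T>(T &backend) { return backend.generate(in, out); };
        if (auto result = std::visit(fw, inner_); result.has_value()) {
            return *result;
        }
        throw std::runtime_error("generation failed");
    }
};

int main() {
    const std::vector<token_t> prompt{128000, 5159, 836, 374, 23809};
    std::vector<token_t> generated(16, 0);

    backend_facade_t backend{echo_backend_t{}};
    const auto n = backend.generate(prompt, generated);
    std::cout << "copied " << n << " tokens\n";
}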