From 86d30aea43c6b858fa260aaa49b2c95320f97236 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Sat, 9 Nov 2024 22:10:33 +0100
Subject: [PATCH] feat(backend): simplify overall cpp structure

---
 backends/llamacpp/csrc/backend.cpp | 103 ++++----------------------
 backends/llamacpp/csrc/backend.hpp | 110 ++--------------------------
 backends/llamacpp/csrc/ffi.hpp     |  79 +++++++-------------
 backends/llamacpp/offline/main.cpp |  43 +++++++----
 backends/llamacpp/src/backend.rs   | 113 +++++++++++++++++------------
 backends/llamacpp/src/lib.rs       |   9 +--
 backends/llamacpp/src/main.rs      |   8 +-
 7 files changed, 144 insertions(+), 321 deletions(-)

diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp
index 11781273..837f87ea 100644
--- a/backends/llamacpp/csrc/backend.cpp
+++ b/backends/llamacpp/csrc/backend.cpp
@@ -49,43 +49,28 @@ namespace huggingface::tgi::backends::llamacpp {
         }
 
         llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
-        return llama_sampler_ptr(pSampler, llama_sampler_deleter);
+        return {pSampler, llama_sampler_deleter};
     }
 
     worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params)
-            : mModel_(model), mParams_(params) {
+            : model_(model), context_(llama_new_context_with_model(model_.get(), params)) {
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
         char modelName[256];
         llama_model_meta_val_str(model.get(), "general.name", modelName, sizeof(modelName));
-        SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
+        SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
 #endif
     }
 
-    void worker_t::loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const {
-        auto *context = llama_new_context_with_model(mModel_.get(), mParams_);
-
-        while (!driver.stop_requested()) {
-            const auto generation_context = backlog.front();
-
-            generate(context, generation_context, std::nullopt);
-            backlog.pop();
-
-            SPDLOG_DEBUG("Processed request ({:d} remaining)", backlog.size());
-        }
-
-        llama_free(context);
-    }
-
-    size_t worker_t::generate(
-            llama_context *context,
-            const generation_context_t &generation_context,
-            const std::optional<llama_decode_callback> &callback) const {
+    std::expected<size_t, backend_error_t>
+    worker_t::generate(const generation_context_t &generation_context,
+                       const std::optional<llama_decode_callback> &callback) const {
         // Store information about context and generation size
+        const auto callback_ = callback.value_or(llama_void_callback);
         auto max_new_tokens = generation_context.generation_params.max_new_tokens;
 
         // Convert sampling params to what llama.cpp is looking for
-        auto sampler = generation_context.sampling_params.into_llama_sampler(mModel_.get());
+        auto sampler = generation_context.sampling_params.into_llama_sampler(model_.get());
 
         // Set up the prompt
         auto copy = std::vector<llama_token>(generation_context.input_tokens.begin(), generation_context.input_tokens.end());
@@ -94,11 +79,10 @@ namespace huggingface::tgi::backends::llamacpp {
         // Decode
         auto n_decoded_tokens = 0;
         for (bool generating = true; generating; ++n_decoded_tokens) {
-            const auto callback_ = callback.value_or(llama_void_callback);
 
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
             const auto start = std::chrono::steady_clock::now();
-            const auto status = llama_decode(context, batch);
+            const auto status = llama_decode(context_.get(), batch);
            const auto end = std::chrono::steady_clock::now();
             const auto latency = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
             SPDLOG_DEBUG(FMT_STRING("Successfully decoded {:d} token(s) in {}"), batch.n_tokens, latency);
@@ -108,8 +92,8 @@ namespace huggingface::tgi::backends::llamacpp {
             batch.n_tokens = 0;
             if (LLAMA_SUCCESS(status)) [[likely]] {
                 // Sample the new token
-                auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
-                auto is_eog = llama_token_is_eog(mModel_.get(), new_token_id);
+                auto new_token_id = llama_sampler_sample(sampler.get(), context_.get(), -1);
+                auto is_eog = llama_token_is_eog(model_.get(), new_token_id);
                 auto new_token_logits = 0.0f; // TODO: return logit
 
                 // Handle termination cases
@@ -119,11 +103,8 @@ namespace huggingface::tgi::backends::llamacpp {
                     generating = !(has_reach_max_tokens | has_reach_eog);
 
                 // Bubble up the generated token if a callback is provided
-                const auto should_stop = std::invoke(std::forward(callback_),
-                                                     new_token_id,
-                                                     new_token_logits,
-                                                     !generating,
-                                                     n_decoded_tokens + 1);
+                const auto should_stop =
+                        std::invoke(callback_, new_token_id, new_token_logits, !generating, n_decoded_tokens + 1);
                 generating ^= should_stop;
 
                 batch = llama_batch_get_one(&new_token_id, 1);
@@ -132,62 +113,4 @@ namespace huggingface::tgi::backends::llamacpp {
 
         return n_decoded_tokens;
     }
-
-
-    backend_base_t::backend_base_t(llama_model *model) : mModel_(model, llama_free_model) { llama_backend_init(); }
-
-    backend_base_t::~backend_base_t() { llama_backend_free(); }
-
-    std::expected<std::vector<llama_token>, backend_error_t> backend_base_t::generate(
-            std::span<const llama_token> tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            const std::optional<llama_decode_callback> &callback
-    ) {
-        // TODO: Should we provide a way to change this value?
-        auto generated = std::vector<llama_token>(2 << 8);
-        auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
-                                  size_t num_generated_tokens) -> bool {
-            generated.emplace_back(new_token_id);
-
-            if (callback.has_value())
-                return (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
-            return true;
-        };
-
-        auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);
-        return generated;
-    }
-
-
-    /** Single worker_t Backend impl **/
-
-    single_worker_backend_t::single_worker_backend_t(llama_model *model,
-                                                     const std::optional<llama_context_params> &params)
-            : backend_base_t(model),
-              mContext_(llama_context_factory(model)),
-              mWorker_(mModel_, params.value_or(llama_context_default_params())) {
-        llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
-    };
-
-    std::expected<size_t, backend_error_t>
-    single_worker_backend_t::stream(
-            std::span<const llama_token> tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            const llama_decode_callback &callback
-    ) {
-        return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens}, callback);
-    }
-
-    std::expected<size_t, backend_error_t>
-    multi_worker_backend_t::stream(
-            std::span<const llama_token> tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            const llama_decode_callback &callback
-    ) {
-        SPDLOG_WARN("Not implemented for multi_worker_t");
-        return 0;
-    }
 }
\ No newline at end of file
diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp
index 4abc202d..de37df75 100644
--- a/backends/llamacpp/csrc/backend.hpp
+++ b/backends/llamacpp/csrc/backend.hpp
@@ -76,8 +76,8 @@ namespace huggingface::tgi::backends::llamacpp {
      */
     class worker_t {
     private:
-        const std::shared_ptr<llama_model> mModel_;
-        const llama_context_params mParams_;
+        std::shared_ptr<llama_model> model_;
+        llama_context_ptr context_;
 
     public:
         /**
@@ -85,7 +85,7 @@ namespace huggingface::tgi::backends::llamacpp {
          * @param model
          * @param params
          */
-        worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params);
+        worker_t(std::shared_ptr<llama_model>, const llama_context_params &);
 
         /**
          *
@@ -93,108 +93,8 @@ namespace huggingface::tgi::backends::llamacpp {
          * @param generation_context
          * @param callback
          */
-        size_t
-        generate(llama_context *, const generation_context_t &, const std::optional<llama_decode_callback> &) const;
-
-        /**
-         *
-         */
-        void loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const;
-    };
-
-
-    class backend_base_t {
-
-    protected:
-        std::shared_ptr<llama_model> mModel_;
-
-    public:
-
-        /**
-         *
-         * @param model
-         */
-        explicit backend_base_t(llama_model *model);
-
-        /**
-         * Destructor
-         */
-        ~backend_base_t();
-
-        /**
-         *
-         * @param tokens
-         * @param generation_params
-         * @param sampling_params
-         * @param callback
-         * @return
-         */
-        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-        std::expected<std::vector<llama_token>, backend_error_t> generate(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const std::optional<llama_decode_callback> &callback = std::nullopt
-        );
-
-        /**
-         *
-         * @param tokens
-         * @param generation_params
-         * @param sampling_params
-         * @params callback
-         * @return
-         */
-        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-        virtual std::expected<size_t, backend_error_t> stream(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const llama_decode_callback &callback
-        ) = 0;
-    };
-
-
-    class single_worker_backend_t : backend_base_t {
-    private:
-        constexpr static auto llama_context_factory = [](llama_model *pModel) -> llama_context_ptr {
-            auto llParams = llama_context_default_params();
-            llParams.flash_attn = true;
-            llParams.n_batch = 1;
-            llParams.n_threads = 1;
-            llParams.no_perf = true;
-            llParams.attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL;
-
-            return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
-        };
-
-        llama_context_ptr mContext_;
-        worker_t mWorker_;
-
-    public:
-        explicit single_worker_backend_t(llama_model *pModel, const std::optional<llama_context_params> &);
-
-        using backend_base_t::generate;
-
-        std::expected<size_t, backend_error_t> stream(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const llama_decode_callback &callback) override;
-    };
-
-    class multi_worker_backend_t : backend_base_t {
-    private:
-        llama_context_ptr mContext_;
-
-    public:
-        using backend_base_t::generate;
-
-        std::expected<size_t, backend_error_t> stream(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const llama_decode_callback &callback) override;
+        [[nodiscard]] std::expected<size_t, backend_error_t>
+        generate(const generation_context_t &, const std::optional<llama_decode_callback> &) const;
     };
 }
diff --git a/backends/llamacpp/csrc/ffi.hpp b/backends/llamacpp/csrc/ffi.hpp
index 9daacf2c..51a524cb 100644
--- a/backends/llamacpp/csrc/ffi.hpp
+++ b/backends/llamacpp/csrc/ffi.hpp
@@ -7,58 +7,41 @@
 #include
 #include
+#include
 #include
 #include
 #include
-#include "backend.hpp"
 
 namespace huggingface::tgi::backends::llamacpp {
-    struct generation_params_t;
-    struct sampling_params_t;
-
-    class llama_cpp_backend_impl_t;
+    class llama_cpp_worker_frontend_t;
 }
 
-
+#include "backend.hpp"
 #include "backends/llamacpp/src/lib.rs.h"
 #include "rust/cxx.h"
 
 namespace huggingface::tgi::backends::llamacpp {
 
-    // Concept identifying types which have a .generate() -> size_t method to do in-place generation
-    template<typename T>
-    concept has_stream_method = requires(
-            T t,
-            std::span<const llama_token> input_tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            llama_decode_callback callback
-    ) {
-        {
-            t.stream(input_tokens, generation_params, sampling_params, callback)
-        } -> std::same_as<std::expected<size_t, backend_error_t>>;
+    auto llama_model_deleter = [](llama_model *model) { llama_free_model(model); };
+    auto make_shared_llama_model = [](llama_model *model) {
+        return std::shared_ptr<llama_model>(model, llama_model_deleter);
     };
 
-    static_assert(has_stream_method<single_worker_backend_t>, "single_worker_backend_t doesn't meet concept has_stream_method");
-    static_assert(has_stream_method<multi_worker_backend_t>, "multi_worker_backend_t doesn't meet concept has_stream_method");
-
-    class llama_cpp_backend_exception_t : std::exception {
-
-    };
+    class llama_cpp_backend_exception_t : std::exception {};
 
     /**
-     * Llama.cpp backend interfacing with Rust FFI layer
+     * Llama.cpp frontend over the worker interfacing with Rust FFI layer
      */
-    class llama_cpp_backend_impl_t {
+    class llama_cpp_worker_frontend_t {
    private:
-        std::variant<single_worker_backend_t, multi_worker_backend_t> mInner_;
+        std::shared_ptr<llama_model> model_;
+        worker_t worker_;
 
     public:
-        explicit llama_cpp_backend_impl_t(single_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
-
-        explicit llama_cpp_backend_impl_t(multi_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
+        explicit llama_cpp_worker_frontend_t(llama_model *model):
+            model_{ make_shared_llama_model(model) }, worker_(model_, {.no_perf = true}) {}
 
         size_t stream(
                 rust::Slice<const uint32_t> input_tokens,
@@ -67,41 +50,31 @@ namespace huggingface::tgi::backends::llamacpp {
                 InferContext *ctx,
                 rust::Fn<bool(InferContext *, uint32_t, float_t, bool, size_t)> callback
         ) {
-            // Define the visitor lambda function which requires the has_emplace_generate constraint on T
-            auto inner_fw = [=, &sampling_params, &ctx, &callback]<has_stream_method T>(T &&backend)
-                    -> std::expected<size_t, backend_error_t> {
-
-                auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
-                    return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
-                };
-
-                // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
-                auto input_tokens_v =
-                        std::span(reinterpret_cast<const int32_t *>(input_tokens.data()), input_tokens.size());
-
-                return backend.stream(
-                        input_tokens_v,
-                        generation_params,
-                        sampling_params,
-                        context_forwarding_callback
-                );
+            auto context_forwarding_callback =
+                    [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
+                return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
             };
 
-            if (const auto result = std::visit(inner_fw, mInner_); result.has_value()) {
+            // Ask the compiler to create view over Rust slice transmuting from uint32_t* to llama_token*
+            auto input_tokens_v =
+                    std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
+
+            const auto generation_context = generation_context_t {generation_params, sampling_params, input_tokens_v};
+            if(const auto result = worker_.generate(generation_context, context_forwarding_callback); result.has_value()) [[likely]] {
                 return *result;
             } else {
-                throw llama_cpp_backend_exception_t();
+                throw llama_cpp_backend_exception_t {};
             }
         }
     };
 
-    std::unique_ptr<llama_cpp_backend_impl_t> create_single_worker_backend(rust::Str modelPath) {
+    std::unique_ptr<llama_cpp_worker_frontend_t> create_worker_frontend(rust::Str modelPath) {
         const auto cxxPath = std::string(modelPath);
         auto params = llama_model_default_params();
         params.use_mmap = true;
 
-        auto *model = llama_load_model_from_file(cxxPath.c_str(), params);
-        return std::make_unique<llama_cpp_backend_impl_t>(single_worker_backend_t { model, std::nullopt });
+        auto *model = (llama_load_model_from_file(cxxPath.c_str(), params));
+        return std::make_unique<llama_cpp_worker_frontend_t>(model);
     }
 }
diff --git a/backends/llamacpp/offline/main.cpp b/backends/llamacpp/offline/main.cpp
index 7eb7dbde..721abf05 100644
--- a/backends/llamacpp/offline/main.cpp
+++ b/backends/llamacpp/offline/main.cpp
@@ -1,16 +1,17 @@
 //
 // Created by mfuntowicz on 10/3/24.
 //
+#include
 
-#include
-#include
-#include
-#include
+#include
 #include
+#include
 #include "../csrc/backend.hpp"
 
 using namespace huggingface::tgi::backends::llamacpp;
 
+const auto llama_model_deleter = [](llama_model *model) { llama_free_model(model); };
+
 int main(int argc, char **argv) {
     if (argc < 2) {
         fmt::print("No model folder provided");
@@ -18,21 +19,31 @@ int main(int argc, char **argv) {
     }
 
     spdlog::set_level(spdlog::level::debug);
-
+
     const auto modelPath = absolute(std::filesystem::path(argv[1]));
     const auto params = llama_model_default_params();
-    auto *model = llama_load_model_from_file(modelPath.c_str(), params);
+    auto model = std::unique_ptr<llama_model, decltype(llama_model_deleter)>(
+            llama_load_model_from_file(modelPath.c_str(), params)
+    );
 
-    auto backend = single_worker_backend_t(model, {});
+    auto prompt = "My name is Morgan";
+    auto tokens = std::vector<llama_token>(16);
+    const auto nb_tokens = llama_tokenize(model.get(), prompt, sizeof(prompt), tokens.data(), tokens.size(), true,
+                                          false);
+    tokens.resize(nb_tokens);
+    auto backend = worker_t{std::move(model), {.n_batch = 1, .n_threads = 4}};
+
+    fmt::println("Tokenized: {}", tokens);
 
     // generate
-    const auto promptTokens = {128000, 5159, 836, 374, 23809, 11};
-    const auto out = backend.generate(promptTokens, {.max_new_tokens = 32}, {.top_k = 40});
-
-    if (out.has_value())
-        fmt::print(FMT_STRING("Generated: {}"), *out);
-    else {
-        const auto err = out.error();
-        fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
-    }
+    auto generated_tokens = std::vector<llama_token>(32);
+    const auto n_generated_tokens = backend.generate(
+            {{.max_new_tokens = 32}, {.top_k = 40}, tokens},
+            [&generated_tokens](llama_token new_token_id, float_t logit, bool is_eos, size_t step) -> bool {
+                generated_tokens.emplace(generated_tokens.begin() + (step - 1), new_token_id);
+                return false;
+            }
+    );
+    generated_tokens.resize(n_generated_tokens.value());
+    fmt::println("Generated {} tokens", generated_tokens);
 }
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 8214c36a..8e36aa63 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -1,8 +1,9 @@
 use crate::ffi::{
-    create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
+    create_worker_frontend, GenerationParams, LlamaCppWorkerFrontend, SamplingParams,
 };
 use async_trait::async_trait;
 use cxx::UniquePtr;
+use std::ops::Deref;
 use std::path::{Path, PathBuf};
 use std::sync::mpsc::{channel, Receiver, Sender};
 use std::sync::Arc;
@@ -21,7 +22,7 @@ use tracing::{debug, error, info};
 
 type InferResult<T> = Result<T, InferError>;
 
-unsafe impl Send for LlamaCppBackendImpl {}
+unsafe impl Send for LlamaCppWorkerFrontend {}
 
 impl From<&ValidParameters> for SamplingParams {
     fn from(v: &ValidParameters) -> Self {
@@ -68,41 +69,54 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }
 
-pub struct LlamaCppBackend {
-    backlog: Sender<(GenerationContext, UnboundedSender<InferResult<InferStreamResponse>>)>,
-    _scheduler_handle: JoinHandle<()>,
+// pub struct LlamaCppBackend {
+//     backlog: Sender<(GenerationContext, UnboundedSender<InferResult<InferStreamResponse>>)>,
+//     _scheduler_handle: JoinHandle<()>,
+// }
+
+struct LlamaCppWorker {
+    sender: Sender<(GenerationContext, UnboundedSender<InferResult<InferStreamResponse>>)>,
+    handle: JoinHandle<()>,
+}
+
+pub enum LlamaCppBackend {
+    Single(LlamaCppWorker),
+    // Multi(Vec<LlamaCppWorker>)
 }
 
 impl LlamaCppBackend {
-    pub fn new<P: AsRef<Path> + Send>(
+    fn allocate_worker(
+        path: &Path,
+    ) -> Result<UniquePtr<LlamaCppWorkerFrontend>, LlamaCppBackendError> {
+        create_worker_frontend(&path.display().to_string()).map_err(|ref err| {
+            LlamaCppBackendError::ModelInitializationFailed(path.to_path_buf(), err.to_string())
+        })
+    }
+
+    pub fn new<P: AsRef<Path>>(
         model_path: P,
         tokenizer: Tokenizer,
+        num_cores_per_instance: u16,
     ) -> Result<Self, LlamaCppBackendError> {
-        let path = Arc::new(model_path.as_ref());
+        let shared_path = Arc::new(model_path);
+        let path = shared_path.deref().as_ref();
         if !path.exists() {
             return Err(LlamaCppBackendError::ModelFileDoesntExist(
                 path.display().to_string(),
             ));
         }
 
-        let backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
-            LlamaCppBackendError::ModelInitializationFailed(
-                path.to_path_buf(),
-                err.what().to_string(),
-            )
-        })?;
+        let worker = match num_cores_per_instance {
+            0 => {
+                let worker = Self::allocate_worker(path)?;
+                let (sender, receiver) = channel();
+                let handle = spawn(|| scheduler_loop(worker, tokenizer, receiver));
+                LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
+            }
+            _ => panic!("Not supported yet"),
+        };
 
-        info!(
-            "Successfully initialized llama.cpp backend from {}",
-            path.display()
-        );
-
-        let (submitter, receiver) = channel();
-        let handle = unsafe { spawn(|| scheduler_loop(backend, tokenizer, receiver)) };
-        Ok(Self {
-            backlog: submitter,
-            _scheduler_handle: handle,
-        })
+        Ok(worker)
     }
 }
 
@@ -169,18 +183,16 @@ fn llama_generate_callback(
     };
 
     // Send back to the client
-    let should_stop = if let Err(ref _err) = ctx.stream.send(response) {
+    if let Err(ref _err) = ctx.stream.send(response) {
         error!("Failed to send back the response to the client, cancelling request");
         true
     } else {
-        true
-    };
-
-    should_stop
+        false
+    }
 }
 
-unsafe fn scheduler_loop(
-    mut backend: UniquePtr<LlamaCppBackendImpl>,
+fn scheduler_loop(
+    mut backend: UniquePtr<LlamaCppWorkerFrontend>,
     tokenizer: Tokenizer,
     backlog: Receiver<(GenerationContext, UnboundedSender<InferResult<InferStreamResponse>>)>,
 ) {
@@ -204,20 +216,23 @@
                     generation,
                 });
 
-                let boxed_ctx = Box::into_raw(ctx);
+                // We leak the box to avoid it being freed after the first callback call
+                // when going out of scope
+                unsafe {
+                    let boxed_ctx = Box::into_raw(ctx);
+                    if let Err(e) = backend.pin_mut().stream(
+                        &input_tokens,
+                        generation_params,
+                        &sampling_params,
+                        boxed_ctx,
+                        llama_generate_callback,
+                    ) {
+                        error!("Error while decoding tokens... {}", e.what());
+                    }
 
-                if let Err(e) = backend.pin_mut().stream(
-                    &input_tokens,
-                    generation_params,
-                    &sampling_params,
-                    boxed_ctx,
-                    llama_generate_callback,
-                ) {
-                    error!("Error while decoding tokens... {}", e.what());
+                    // Make sure we re-keep track of the OpaqueStream box
+                    let _ = Box::from_raw(boxed_ctx);
                 }
-
-                // Make sure we re-keep track of the OpaqueStream box
-                let _ = Box::from_raw(boxed_ctx);
             }
         } else {
             info!("IPC channel is closed, exiting the scheduler loop");
@@ -244,11 +259,13 @@ impl Backend for LlamaCppBackend {
                 sampling_params,
             };
 
-            match self.backlog.send((ctx, sx)) {
-                Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
-                Err(_) => Err(InferError::GenerationError(
-                    "Failed to sent the request".to_string(),
-                )),
+            match self {
+                LlamaCppBackend::Single(worker) => match worker.sender.send((ctx, sx)) {
+                    Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
+                    Err(_) => Err(InferError::GenerationError(
+                        "Failed to send the request".to_string(),
+                    )),
+                },
             }
         } else {
             Err(InferError::GenerationError(
diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs
index abcdd1fa..4f0fa800 100644
--- a/backends/llamacpp/src/lib.rs
+++ b/backends/llamacpp/src/lib.rs
@@ -46,14 +46,13 @@ mod ffi {
         type SamplingParams;
 
         /// Represent an instance of the llama.cpp backend instance on C++ side
-        #[cxx_name = "llama_cpp_backend_impl_t"]
-        type LlamaCppBackendImpl;
+        #[cxx_name = "llama_cpp_worker_frontend_t"]
+        type LlamaCppWorkerFrontend;
 
-        #[rust_name = "create_single_worker_backend"]
-        fn create_single_worker_backend(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;
+        fn create_worker_frontend(modelPath: &str) -> Result<UniquePtr<LlamaCppWorkerFrontend>>;
 
         unsafe fn stream(
-            self: Pin<&mut LlamaCppBackendImpl>,
+            self: Pin<&mut LlamaCppWorkerFrontend>,
             tokens: &[u32],
             generation_params: GenerationParams,
             sampling_params: &SamplingParams,
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index c5d735ab..a2abd555 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -37,8 +37,8 @@ struct Args {
     port: u16,
     #[clap(long, env, help = "Path to GGUF model file(s) to load")]
     gguf_path: PathBuf,
-    // #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
-    // num_model_instance: u16,
+    #[clap(long, env, help = "Number of CPU cores per instance")]
+    num_cores_per_instance: Option<u16>,
     #[clap(long, env, required = true)]
     tokenizer_name: String,
     #[clap(long, env)]
@@ -95,7 +95,7 @@ async fn main() -> Result<(), RouterError> {
         hostname,
         port,
         gguf_path,
-        // num_model_instance,
+        num_cores_per_instance,
         tokenizer_name,
         tokenizer_config_path,
         revision,
@@ -164,7 +164,7 @@ async fn main() -> Result<(), RouterError> {
     };
     let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
         .expect("Failed to retrieve tokenizer");
-    let backend = LlamaCppBackend::new(gguf_path, tokenizer)?;
+    let backend = LlamaCppBackend::new(gguf_path, tokenizer, num_cores_per_instance.unwrap_or(0))?;
 
     // Run server
     server::run(