feat(backend): full rework of the backend internals to safer C++

Morgan Funtowicz 2024-10-31 17:51:57 +01:00
parent 6a5f6b0755
commit d52b4c4978
6 changed files with 166 additions and 44 deletions

View File

@@ -21,7 +21,7 @@ namespace huggingface::tgi::backends::llamacpp {
             batch.token[i] = input_tokens[i];
             batch.pos[i] = i;
             batch.n_seq_id[i] = 1;
-            batch.seq_id[i] = 0;
+            batch.seq_id[i] = nullptr;
             batch.logits[i] = false;
             ++batch.n_tokens;
         }
@@ -84,13 +84,12 @@ namespace huggingface::tgi::backends::llamacpp {
             const generation_context_t &generation_context,
             const std::optional<llama_decode_callback> &callback) const {
         // Store information about context and generation size
-        auto prompt_length = std::ssize(generation_context.input_tokens);
         auto max_new_tokens = generation_context.generation_params.max_new_tokens;

         // Convert sampling params to what llama.cpp is looking for
         auto sampler = generation_context.sampling_params.into_llama_sampler(mModel_.get());

-        // Setup the prompt
+        // Set up the prompt
         auto copy = std::vector(generation_context.input_tokens.begin(), generation_context.input_tokens.end());
         auto batch = llama_batch_get_one(copy.data(), copy.size());
@@ -168,4 +167,15 @@ namespace huggingface::tgi::backends::llamacpp {
     ) {
         return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens, out}, callback);
     }
+
+    std::expected<size_t, backend_error_t>
+    multi_worker_backend_t::generate(
+            std::span<const llama_token>,
+            std::span<llama_token>,
+            const generation_params_t &generation_params,
+            const sampling_params_t &sampling_params,
+            const std::optional<llama_decode_callback> &callback) {
+        SPDLOG_ERROR("Not implemented yet");
+        return 0uz;
+    }
 }
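Note (not part of the diff): the reworked generate() paths report failures through std::expected<size_t, backend_error_t> rather than raw status codes or exceptions inside the worker. A minimal, self-contained sketch of that calling convention follows; backend_error_t is defined outside this commit, so the enum and the fake_generate helper below are hypothetical stand-ins.

#include <cstddef>
#include <cstdio>
#include <expected>

// Hypothetical stand-in for the real backend_error_t defined in the backend headers.
enum class backend_error_t { GENERATION_FAILED };

// Toy stand-in for a worker's generate(): fails when given an empty token budget.
std::expected<size_t, backend_error_t> fake_generate(size_t max_new_tokens) {
    if (max_new_tokens == 0)
        return std::unexpected(backend_error_t::GENERATION_FAILED);
    return max_new_tokens;  // pretend the whole budget was generated
}

int main() {
    // Callers branch on the expected instead of catching exceptions.
    if (const auto result = fake_generate(16); result.has_value())
        std::printf("generated %zu tokens\n", *result);
    else
        std::printf("generation failed\n");
    return 0;
}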

View File

@@ -180,8 +180,20 @@ namespace huggingface::tgi::backends::llamacpp {
                 const sampling_params_t &sampling_params,
                 const std::optional<llama_decode_callback> &callback
         ) override;
+    };
+
+    class multi_worker_backend_t : backend_base_t {
+    private:
+        llama_context_smart_ptr mContext_;
+
+    public:
+        std::expected<size_t, backend_error_t> generate(
+                std::span<const llama_token>,
+                std::span<llama_token>,
+                const generation_params_t &generation_params,
+                const sampling_params_t &sampling_params,
+                const std::optional<llama_decode_callback> &callback
+        ) override;
     };
 }
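Note (not part of the diff): multi_worker_backend_t owns its llama.cpp context through llama_context_smart_ptr, whose definition is not in this hunk. A plausible sketch, assuming it is a std::unique_ptr with a custom deleter that releases the handle through llama.cpp's llama_free(); the alias and deleter names here are illustrative and may differ from the real header.

#include <memory>

#include <llama.h>

// Deleter that returns the context through llama.cpp's own API instead of delete.
struct llama_context_deleter {
    void operator()(llama_context *ctx) const noexcept { llama_free(ctx); }
};

using llama_context_smart_ptr = std::unique_ptr<llama_context, llama_context_deleter>;

// Usage sketch: wrap the raw handle as soon as it is created, e.g.
//   llama_context_smart_ptr ctx{llama_new_context_with_model(model, params)};
// so the context is freed on every exit path, which is the point of the "safer C++" rework.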

View File

@@ -12,36 +12,92 @@
 #include <spdlog/spdlog.h>
 #include "backend.hpp"

-namespace huggingface::tgi::backends::llamacpp::impl {
-    class LlamaCppBackendImpl;
+namespace huggingface::tgi::backends::llamacpp {
+    struct generation_params_t;
+    struct sampling_params_t;
+    class llama_cpp_backend_impl_t;
 }

 #include "backends/llamacpp/src/lib.rs.h"

-namespace huggingface::tgi::backends::llamacpp::impl {
-    class LlamaCppBackendException : std::exception {
+namespace huggingface::tgi::backends::llamacpp {
+
+    // Concept identifying types which have a .generate() -> size_t method to do in-place generation
+    template<typename T>
+    concept has_emplace_generate = requires(
+            T t,
+            std::span<const llama_token> input_tokens,
+            std::span<llama_token> generated_tokens,
+            const generation_params_t &generation_params,
+            const sampling_params_t &sampling_params,
+            llama_decode_callback callback
+    ) {
+        {
+            t.generate(input_tokens, generated_tokens, generation_params, sampling_params, callback)
+        } -> std::same_as<std::expected<size_t, backend_error_t>>;
+    };
+
+    static_assert(has_emplace_generate<single_worker_backend_t>,
+                  "single_worker_backend_t doesn't meet concept is_generate_emplace_capable");
+    static_assert(has_emplace_generate<multi_worker_backend_t>,
+                  "multi_worker_backend_t doesn't meet concept is_generate_emplace_capable");
+
+    class llama_cpp_backend_exception_t : std::exception {

     };

-    class LlamaCppBackendImpl {
+    /**
+     * Llama.cpp backend interfacing with Rust FFI layer
+     */
+    class llama_cpp_backend_impl_t {
     private:
-        BackendBase _inner;
+        std::variant<single_worker_backend_t, multi_worker_backend_t> mInner_;

     public:
-        LlamaCppBackendImpl(llama_model *model) : _inner(model) {}
+        explicit llama_cpp_backend_impl_t(single_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
+
+        explicit llama_cpp_backend_impl_t(multi_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
+
+        size_t generate(
+                rust::Slice<const uint32_t> input_tokens,
+                rust::Slice<uint32_t> generated_tokens,
+                const generation_params_t &generation_params,
+                const sampling_params_t &sampling_params,
+                rust::Fn<void(uint32_t, bool)> callback
+        ) {
+            // Define the visitor lambda function which requires the has_emplace_generate constraint on T
+            static auto inner_fw = [=, &generation_params, &sampling_params]<has_emplace_generate T>(T &&backend)
+                    -> std::expected<size_t, backend_error_t> {
+
+                // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
+                auto input_tokens_v =
+                        std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
+                auto generated_tokens_v =
+                        std::span(reinterpret_cast<llama_token *>(generated_tokens.data()), generated_tokens.size());
+
+                return backend.generate(
+                        input_tokens_v, generated_tokens_v, generation_params, sampling_params, callback);
+            };
+
+            if (const auto result = std::visit(inner_fw, mInner_); result.has_value()) {
+                return *result;
+            } else {
+                throw llama_cpp_backend_exception_t();
+            }
+        }
     };

-    std::unique_ptr<LlamaCppBackendImpl> CreateLlamaCppBackendImpl(rust::Str modelPath, uint16_t nThreads) {
-        const auto cxxPath = std::string_view(modelPath);
-        if (auto maybe = TgiLlamaCppBackend::FromGGUF(std::filesystem::path(cxxPath), nThreads); maybe.has_value()) {
-            auto [model, context] = *maybe;
-            return std::make_unique<LlamaCppBackendImpl>(model, context);
-        } else {
-            throw LlamaCppBackendException();
-        }
+    std::unique_ptr<llama_cpp_backend_impl_t> create_single_worker_backend(rust::Str modelPath) {
+        const auto cxxPath = std::string(modelPath);
+        auto params = llama_model_default_params();
+        params.use_mmap = true;
+
+        auto *model = llama_load_model_from_file(cxxPath.c_str(), params);
+        auto backend = single_worker_backend_t(model, std::nullopt);
+        return std::make_unique<llama_cpp_backend_impl_t>(std::move(backend));
     }
 }
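Note (not part of the diff): the new FFI entry point dispatches over a std::variant of worker types with a generic lambda constrained by the has_emplace_generate concept. A self-contained sketch of that pattern follows; all names in it (toy_single_worker, toy_multi_worker, toy_can_generate, dispatch) are illustrative and not part of the backend.

#include <concepts>
#include <cstddef>
#include <cstdio>
#include <expected>
#include <variant>

enum class toy_error_t { failed };

struct toy_single_worker {
    std::expected<size_t, toy_error_t> generate(size_t n) const { return n; }
};

struct toy_multi_worker {
    std::expected<size_t, toy_error_t> generate(size_t) const {
        return std::unexpected(toy_error_t::failed);  // mimics the "not implemented yet" stub
    }
};

// Any type exposing generate(size_t) -> std::expected<size_t, toy_error_t> qualifies.
template<typename T>
concept toy_can_generate = requires(T t, size_t n) {
    { t.generate(n) } -> std::same_as<std::expected<size_t, toy_error_t>>;
};

size_t dispatch(const std::variant<toy_single_worker, toy_multi_worker> &backend, size_t n) {
    // The constrained generic lambda only accepts alternatives satisfying the concept.
    auto visitor = [n]<toy_can_generate T>(const T &b) { return b.generate(n); };
    if (const auto result = std::visit(visitor, backend); result.has_value())
        return *result;
    return 0;  // the real code throws here and lets cxx surface the error on the Rust side
}

int main() {
    std::printf("%zu\n", dispatch(toy_single_worker{}, 16));
    std::printf("%zu\n", dispatch(toy_multi_worker{}, 16));
    return 0;
}

The concept turns an incompatible backend type into a compile-time error instead of a runtime surprise, which is what the static_asserts on single_worker_backend_t and multi_worker_backend_t check above.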

View File

@@ -1,7 +1,8 @@
-use crate::ffi::{create_llamacpp_backend, LlamaCppBackendImpl};
+use crate::ffi::{
+    create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
+};
 use async_trait::async_trait;
 use cxx::{Exception, UniquePtr};
-use std::ops::Deref;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::thread::spawn;
@@ -25,10 +26,7 @@ pub enum LlamaCppBackendError {
 pub struct LlamaCppBackend {}

 impl LlamaCppBackend {
-    pub fn new<P: AsRef<Path> + Send>(
-        model_path: P,
-        n_threads: u16,
-    ) -> Result<Self, LlamaCppBackendError> {
+    pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
         let path = Arc::new(model_path.as_ref());
         if !path.exists() {
             return Err(LlamaCppBackendError::ModelFileDoesntExist(
@@ -36,13 +34,12 @@ impl LlamaCppBackend {
             ));
         }

-        let mut backend =
-            create_llamacpp_backend(path.to_str().unwrap(), n_threads).map_err(|err| {
-                LlamaCppBackendError::ModelInitializationFailed(
-                    path.to_path_buf(),
-                    err.what().to_string(),
-                )
-            })?;
+        let mut backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
+            LlamaCppBackendError::ModelInitializationFailed(
+                path.to_path_buf(),
+                err.what().to_string(),
+            )
+        })?;

         info!(
             "Successfully initialized llama.cpp backend from {}",
@@ -57,12 +54,20 @@ impl LlamaCppBackend {
 fn scheduler_loop(mut backend: UniquePtr<LlamaCppBackendImpl>) {
     println!("Scheduler loop");
-    let tokens = [128000i32, 5159, 836, 374, 23809];
-    let mut generated = vec![0i32; 128];
-    match backend
-        .pin_mut()
-        .generate(&tokens, &mut generated, 40, 32, 1.0, 1.0, 1.0, 1.0, 2014)
-    {
+    let tokens = [128000u32, 5159, 836, 374, 23809];
+    let mut generated = vec![0u32; 16];
+    let generation_params = GenerationParams {
+        max_new_tokens: generated.len() as u32,
+    };
+    let sampling_params = SamplingParams::default();
+
+    match backend.pin_mut().generate(
+        &tokens,
+        &mut generated,
+        &generation_params,
+        &sampling_params,
+        |new_token_id: u32, is_eos: bool| println!("Generated {new_token_id} (is_eos: {is_eos})"),
+    ) {
         Ok(n_tokens) => {
             generated.truncate(n_tokens);
             println!("Generated {} tokens -> {:?}", n_tokens, generated);

View File

@@ -1,17 +1,56 @@
+use crate::ffi::SamplingParams;
+
 pub mod backend;

-#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp::impl")]
+impl Default for SamplingParams {
+    fn default() -> Self {
+        Self {
+            top_k: u32::MAX,
+            top_p: 1.0f32,
+            frequency_penalty: 0.0f32,
+            repetition_penalty: 0.0f32,
+            seed: 2014u64,
+        }
+    }
+}
+
+#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
 mod ffi {
+    struct GenerationParams {
+        max_new_tokens: u32,
+    }
+
+    struct SamplingParams {
+        top_k: u32,
+        top_p: f32,
+        frequency_penalty: f32,
+        repetition_penalty: f32,
+        seed: u64,
+    }
+
     unsafe extern "C++" {
         include!("backends/llamacpp/csrc/ffi.hpp");

+        #[cxx_name = "generation_params_t"]
+        type GenerationParams;
+
+        #[cxx_name = "sampling_params_t"]
+        type SamplingParams;
+
         /// Represent an instance of the llama.cpp backend instance on C++ side
+        #[cxx_name = "llama_cpp_backend_impl_t"]
         type LlamaCppBackendImpl;

-        #[rust_name = "create_llamacpp_backend"]
-        fn CreateLlamaCppBackendImpl(
-            modelPath: &str,
-            n_threads: u16,
-        ) -> Result<UniquePtr<LlamaCppBackendImpl>>;
+        #[rust_name = "create_single_worker_backend"]
+        fn create_single_worker_backend(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;
+
+        fn generate(
+            self: Pin<&mut LlamaCppBackendImpl>,
+            tokens: &[u32],
+            generated: &mut [u32],
+            generation_params: &GenerationParams,
+            sampling_params: &SamplingParams,
+            callback: fn(u32, bool),
+        ) -> Result<usize>;
     }
 }
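Note (not part of the diff): the bridge exposes GenerationParams and SamplingParams to C++ under the names generation_params_t and sampling_params_t. Their C++ definitions are not shown in this commit; the sketch below only mirrors the field layout declared in the bridge above, purely for illustration, and the real headers may differ.

#include <cstdint>

// Illustrative mirrors of the shared structs declared in the cxx bridge.
struct generation_params_t {
    uint32_t max_new_tokens;
};

struct sampling_params_t {
    uint32_t top_k;
    float top_p;
    float frequency_penalty;
    float repetition_penalty;
    uint64_t seed;
};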

View File

@@ -161,7 +161,7 @@ async fn main() -> Result<(), RouterError> {
         }
     }

-    let backend = LlamaCppBackend::new(gguf_path, cores_per_instance)?;
+    let backend = LlamaCppBackend::new(gguf_path)?;

     // Run server
     server::run(