mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-26 20:42:06 +00:00)

feat(backend): full rework of the backend internal to safer c++
This commit is contained in: parent 6a5f6b0755, commit d52b4c4978

@@ -21,7 +21,7 @@ namespace huggingface::tgi::backends::llamacpp {
             batch.token[i] = input_tokens[i];
             batch.pos[i] = i;
             batch.n_seq_id[i] = 1;
-            batch.seq_id[i] = 0;
+            batch.seq_id[i] = nullptr;
             batch.logits[i] = false;
             ++batch.n_tokens;
         }
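
Context for the one-line change above: in a llama.cpp-style batch, seq_id is an array of per-token pointers (llama_seq_id **), so element i is itself a pointer. Assigning 0 already produced a null pointer, but nullptr spells that intent out, in line with the commit's "safer c++" goal. A minimal stand-alone sketch of the same fill pattern follows; token_batch is a stand-in type for illustration, not llama.cpp's actual llama_batch.

    #include <cstddef>
    #include <cstdint>
    #include <span>
    #include <vector>

    using llama_token  = std::int32_t;
    using llama_pos    = std::int32_t;
    using llama_seq_id = std::int32_t;

    // Stand-in for llama.cpp's llama_batch: seq_id holds one pointer per token,
    // so the "no sequence list attached" value for an element is a null pointer.
    struct token_batch {
        std::int32_t n_tokens = 0;
        std::vector<llama_token>    token;
        std::vector<llama_pos>      pos;
        std::vector<std::int32_t>   n_seq_id;
        std::vector<llama_seq_id *> seq_id;
        std::vector<std::uint8_t>   logits;
    };

    token_batch fill_batch(std::span<const llama_token> input_tokens) {
        token_batch batch;
        const auto n = input_tokens.size();
        batch.token.resize(n);
        batch.pos.resize(n);
        batch.n_seq_id.resize(n);
        batch.seq_id.resize(n);
        batch.logits.resize(n);

        for (std::size_t i = 0; i < n; ++i) {
            batch.token[i]    = input_tokens[i];
            batch.pos[i]      = static_cast<llama_pos>(i);
            batch.n_seq_id[i] = 1;
            batch.seq_id[i]   = nullptr;  // a pointer slot, so nullptr rather than 0
            batch.logits[i]   = false;    // do not request logits for prompt tokens
            ++batch.n_tokens;
        }
        return batch;
    }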
@@ -84,7 +84,6 @@ namespace huggingface::tgi::backends::llamacpp {
             const generation_context_t &generation_context,
             const std::optional<llama_decode_callback> &callback) const {
         // Store information about context and generation size
-        auto prompt_length = std::ssize(generation_context.input_tokens);
         auto max_new_tokens = generation_context.generation_params.max_new_tokens;
 
         // Convert sampling params to what llama.cpp is looking for
@@ -168,4 +167,15 @@ namespace huggingface::tgi::backends::llamacpp {
     ) {
         return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens, out}, callback);
     }
+
+    std::expected<size_t, backend_error_t>
+    multi_worker_backend_t::generate(
+            std::span<const llama_token>,
+            std::span<llama_token>,
+            const generation_params_t &generation_params,
+            const sampling_params_t &sampling_params,
+            const std::optional<llama_decode_callback> &callback) {
+        SPDLOG_ERROR("Not implemented yet");
+        return 0uz;
+    }
 }
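
The reworked API threads failures through std::expected<size_t, backend_error_t> instead of throwing across layers; the multi-worker stub above simply returns 0uz until it is implemented. Below is a minimal, self-contained sketch of how a caller consumes such a return value; backend_error_t and generate_stub here are stand-ins invented for the example, not the repo's definitions.

    #include <cstddef>
    #include <cstdio>
    #include <expected>   // C++23
    #include <span>
    #include <vector>

    enum class backend_error_t { decode_failed };  // stand-in; the real enum is defined elsewhere in the backend

    // Pretend backend call: errors travel in-band instead of as exceptions.
    std::expected<std::size_t, backend_error_t>
    generate_stub(std::span<const int> prompt, std::span<int> out) {
        if (out.empty())
            return std::unexpected(backend_error_t::decode_failed);
        out[0] = static_cast<int>(prompt.size());  // "generate" a single token
        return 1uz;                                // C++23 size_t literal, like the 0uz above
    }

    int main() {
        std::vector<int> prompt{1, 2, 3}, generated(16);
        if (const auto n = generate_stub(prompt, generated); n.has_value())
            std::printf("generated %zu token(s)\n", *n);
        else
            std::printf("backend error %d\n", static_cast<int>(n.error()));
    }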
@@ -180,8 +180,20 @@ namespace huggingface::tgi::backends::llamacpp {
                 const sampling_params_t &sampling_params,
                 const std::optional<llama_decode_callback> &callback
         ) override;
+    };
 
+    class multi_worker_backend_t : backend_base_t {
+    private:
+        llama_context_smart_ptr mContext_;
 
+    public:
+        std::expected<size_t, backend_error_t> generate(
+                std::span<const llama_token>,
+                std::span<llama_token>,
+                const generation_params_t &generation_params,
+                const sampling_params_t &sampling_params,
+                const std::optional<llama_decode_callback> &callback
+        ) override;
     };
 }
@@ -12,37 +12,93 @@
 #include <spdlog/spdlog.h>
 #include "backend.hpp"
 
-namespace huggingface::tgi::backends::llamacpp::impl {
-    class LlamaCppBackendImpl;
+namespace huggingface::tgi::backends::llamacpp {
+    struct generation_params_t;
+    struct sampling_params_t;
+
+    class llama_cpp_backend_impl_t;
 }
 
 
 #include "backends/llamacpp/src/lib.rs.h"
 
 
-namespace huggingface::tgi::backends::llamacpp::impl {
+namespace huggingface::tgi::backends::llamacpp {
 
-    class LlamaCppBackendException : std::exception {
+    // Concept identifying types which have a .generate() -> size_t method to do in-place generation
+    template<typename T>
+    concept has_emplace_generate = requires(
+            T t,
+            std::span<const llama_token> input_tokens,
+            std::span<llama_token> generated_tokens,
+            const generation_params_t &generation_params,
+            const sampling_params_t &sampling_params,
+            llama_decode_callback callback
+    ) {
+        {
+            t.generate(input_tokens, generated_tokens, generation_params, sampling_params, callback)
+        } -> std::same_as<std::expected<size_t, backend_error_t>>;
+    };
+
+    static_assert(has_emplace_generate<single_worker_backend_t>,
+                  "single_worker_backend_t doesn't meet concept is_generate_emplace_capable");
+    static_assert(has_emplace_generate<multi_worker_backend_t>,
+                  "multi_worker_backend_t doesn't meet concept is_generate_emplace_capable");
+
+    class llama_cpp_backend_exception_t : std::exception {
 
     };
 
-    class LlamaCppBackendImpl {
+    /**
+     * Llama.cpp backend interfacing with Rust FFI layer
+     */
+    class llama_cpp_backend_impl_t {
     private:
-        BackendBase _inner;
+        std::variant<single_worker_backend_t, multi_worker_backend_t> mInner_;
 
     public:
-        LlamaCppBackendImpl(llama_model *model) : _inner(model) {}
+        explicit llama_cpp_backend_impl_t(single_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
+
+        explicit llama_cpp_backend_impl_t(multi_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
+
+        size_t generate(
+                rust::Slice<const uint32_t> input_tokens,
+                rust::Slice <uint32_t> generated_tokens,
+                const generation_params_t &generation_params,
+                const sampling_params_t &sampling_params,
+                rust::Fn<void(uint32_t, bool)> callback
+        ) {
+            // Define the visitor lambda function which requires the has_emplace_generate constraint on T
+            static auto inner_fw = [=, &generation_params, &sampling_params]<has_emplace_generate T>(T &&backend)
+                    -> std::expected<size_t, backend_error_t> {
+
+                // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
+                auto input_tokens_v =
+                        std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
+                auto generated_tokens_v =
+                        std::span(reinterpret_cast<llama_token *>(generated_tokens.data()), generated_tokens.size());
+
+                return backend.generate(
+                        input_tokens_v, generated_tokens_v, generation_params, sampling_params, callback);
             };
 
-    std::unique_ptr<LlamaCppBackendImpl> CreateLlamaCppBackendImpl(rust::Str modelPath, uint16_t nThreads) {
-        const auto cxxPath = std::string_view(modelPath);
-        if (auto maybe = TgiLlamaCppBackend::FromGGUF(std::filesystem::path(cxxPath), nThreads); maybe.has_value()) {
-            auto [model, context] = *maybe;
-            return std::make_unique<LlamaCppBackendImpl>(model, context);
+            if (const auto result = std::visit(inner_fw, mInner_); result.has_value()) {
+                return *result;
             } else {
-            throw LlamaCppBackendException();
+                throw llama_cpp_backend_exception_t();
             }
         }
+    };
+
+    std::unique_ptr<llama_cpp_backend_impl_t> create_single_worker_backend(rust::Str modelPath) {
+        const auto cxxPath = std::string(modelPath);
+        auto params = llama_model_default_params();
+        params.use_mmap = true;
+
+        auto *model = llama_load_model_from_file(cxxPath.c_str(), params);
+        auto backend = single_worker_backend_t(model, std::nullopt);
+        return std::make_unique<llama_cpp_backend_impl_t>(std::move(backend));
+    }
 }
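
The heart of the new ffi.hpp is the combination of a C++20 concept (has_emplace_generate), a std::variant over the two backend flavours, and a constrained generic lambda dispatched with std::visit. Here is a compact, self-contained sketch of that pattern under simplified types; the worker structs and their generate() bodies are made up for illustration and are not the repo's code.

    #include <concepts>
    #include <cstddef>
    #include <cstdio>
    #include <expected>   // C++23
    #include <variant>

    enum class backend_error_t { failed };  // stand-in

    template<typename T>
    concept has_emplace_generate = requires(T t, int budget) {
        { t.generate(budget) } -> std::same_as<std::expected<std::size_t, backend_error_t>>;
    };

    struct single_worker {
        std::expected<std::size_t, backend_error_t> generate(int budget) { return static_cast<std::size_t>(budget); }
    };
    struct multi_worker {
        std::expected<std::size_t, backend_error_t> generate(int) { return std::unexpected(backend_error_t::failed); }
    };

    // Compilation fails if either alternative stops satisfying the concept.
    static_assert(has_emplace_generate<single_worker>);
    static_assert(has_emplace_generate<multi_worker>);

    int main() {
        std::variant<single_worker, multi_worker> inner = single_worker{};

        // Constrained generic lambda: std::visit instantiates it once per variant alternative.
        auto fw = [budget = 4]<has_emplace_generate T>(T &&backend)
                -> std::expected<std::size_t, backend_error_t> {
            return backend.generate(budget);
        };

        if (const auto result = std::visit(fw, inner); result.has_value())
            std::printf("generated %zu tokens\n", *result);
        else
            std::printf("backend error\n");
    }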
@@ -1,7 +1,8 @@
-use crate::ffi::{create_llamacpp_backend, LlamaCppBackendImpl};
+use crate::ffi::{
+    create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
+};
 use async_trait::async_trait;
 use cxx::{Exception, UniquePtr};
-use std::ops::Deref;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::thread::spawn;
@@ -25,10 +26,7 @@ pub enum LlamaCppBackendError {
 pub struct LlamaCppBackend {}
 
 impl LlamaCppBackend {
-    pub fn new<P: AsRef<Path> + Send>(
-        model_path: P,
-        n_threads: u16,
-    ) -> Result<Self, LlamaCppBackendError> {
+    pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
         let path = Arc::new(model_path.as_ref());
         if !path.exists() {
             return Err(LlamaCppBackendError::ModelFileDoesntExist(
@@ -36,8 +34,7 @@ impl LlamaCppBackend {
             ));
         }
 
-        let mut backend =
-            create_llamacpp_backend(path.to_str().unwrap(), n_threads).map_err(|err| {
+        let mut backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
             LlamaCppBackendError::ModelInitializationFailed(
                 path.to_path_buf(),
                 err.what().to_string(),
@@ -57,12 +54,20 @@ impl LlamaCppBackend {
 
 fn scheduler_loop(mut backend: UniquePtr<LlamaCppBackendImpl>) {
     println!("Scheduler loop");
-    let tokens = [128000i32, 5159, 836, 374, 23809];
-    let mut generated = vec![0i32; 128];
-    match backend
-        .pin_mut()
-        .generate(&tokens, &mut generated, 40, 32, 1.0, 1.0, 1.0, 1.0, 2014)
-    {
+    let tokens = [128000u32, 5159, 836, 374, 23809];
+    let mut generated = vec![0u32; 16];
+    let generation_params = GenerationParams {
+        max_new_tokens: generated.len() as u32,
+    };
+    let sampling_params = SamplingParams::default();
+
+    match backend.pin_mut().generate(
+        &tokens,
+        &mut generated,
+        &generation_params,
+        &sampling_params,
+        |new_token_id: u32, is_eos: bool| println!("Generated {new_token_id} (is_eos: {is_eos})"),
+    ) {
         Ok(n_tokens) => {
             generated.truncate(n_tokens);
             println!("Generated {} tokens -> {:?}", n_tokens, generated);
@@ -1,17 +1,56 @@
+use crate::ffi::SamplingParams;
+
 pub mod backend;
 
-#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp::impl")]
+impl Default for SamplingParams {
+    fn default() -> Self {
+        Self {
+            top_k: u32::MAX,
+            top_p: 1.0f32,
+            frequency_penalty: 0.0f32,
+            repetition_penalty: 0.0f32,
+            seed: 2014u64,
+        }
+    }
+}
+
+#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
 mod ffi {
+    struct GenerationParams {
+        max_new_tokens: u32,
+    }
+
+    struct SamplingParams {
+        top_k: u32,
+        top_p: f32,
+        frequency_penalty: f32,
+        repetition_penalty: f32,
+        seed: u64,
+    }
+
     unsafe extern "C++" {
         include!("backends/llamacpp/csrc/ffi.hpp");
 
+        #[cxx_name = "generation_params_t"]
+        type GenerationParams;
+
+        #[cxx_name = "sampling_params_t"]
+        type SamplingParams;
+
         /// Represent an instance of the llama.cpp backend instance on C++ side
+        #[cxx_name = "llama_cpp_backend_impl_t"]
         type LlamaCppBackendImpl;
 
-        #[rust_name = "create_llamacpp_backend"]
-        fn CreateLlamaCppBackendImpl(
-            modelPath: &str,
-            n_threads: u16,
-        ) -> Result<UniquePtr<LlamaCppBackendImpl>>;
+        #[rust_name = "create_single_worker_backend"]
+        fn create_single_worker_backend(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;
+
+        fn generate(
+            self: Pin<&mut LlamaCppBackendImpl>,
+            tokens: &[u32],
+            generated: &mut [u32],
+            generation_params: &GenerationParams,
+            sampling_params: &SamplingParams,
+            callback: fn(u32, bool),
+        ) -> Result<usize>;
     }
 }
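
One detail worth noting in the bridge above: the streaming callback crosses the FFI as a plain fn(u32, bool), which cxx exposes on the C++ side as the rust::Fn<void(uint32_t, bool)> parameter seen in ffi.hpp, and the backend's generate(...) then takes it via a std::optional callback parameter. A toy sketch of that last hop follows, assuming the callback type is roughly std::function<void(uint32_t, bool)>; this is a stand-in guess, since llama_decode_callback's exact definition is not part of this diff, and rust::Fn is replaced by a plain callable.

    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <optional>

    // Assumed shape of the callback; stand-in for llama_decode_callback.
    using token_callback = std::function<void(std::uint32_t, bool)>;

    // Pretend decode loop: streams a few hard-coded token ids to the callback, if one was supplied.
    void decode_loop(const std::optional<token_callback> &callback) {
        const std::uint32_t fake_tokens[] = {5159, 836, 374};
        for (int i = 0; i < 3; ++i) {
            const bool is_eos = (i == 2);
            if (callback)
                (*callback)(fake_tokens[i], is_eos);
        }
    }

    int main() {
        // Mirrors the Rust closure passed to generate(): print every streamed token.
        decode_loop([](std::uint32_t new_token_id, bool is_eos) {
            std::printf("Generated %u (is_eos: %s)\n",
                        static_cast<unsigned>(new_token_id), is_eos ? "true" : "false");
        });
    }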
@@ -161,7 +161,7 @@ async fn main() -> Result<(), RouterError> {
         }
     }
 
-    let backend = LlamaCppBackend::new(gguf_path, cores_per_instance)?;
+    let backend = LlamaCppBackend::new(gguf_path)?;
 
     // Run server
     server::run(