feat(backend): simplify overall cpp structure

commit 86d30aea43
parent 4f5397c414
Author: Morgan Funtowicz
Date:   2024-11-09 22:10:33 +01:00
7 changed files with 144 additions and 321 deletions

View File

@@ -49,43 +49,28 @@ namespace huggingface::tgi::backends::llamacpp {
         }
         llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
-        return llama_sampler_ptr(pSampler, llama_sampler_deleter);
+        return {pSampler, llama_sampler_deleter};
     }

     worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params)
-            : mModel_(model), mParams_(params) {
+            : model_(model), context_(llama_new_context_with_model(model_.get(), params)) {
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
         char modelName[256];
         llama_model_meta_val_str(model.get(), "general.name", modelName, sizeof(modelName));
         SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
 #endif
     }

-    void worker_t::loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const {
-        auto *context = llama_new_context_with_model(mModel_.get(), mParams_);
-
-        while (!driver.stop_requested()) {
-            const auto generation_context = backlog.front();
-
-            generate(context, generation_context, std::nullopt);
-            backlog.pop();
-
-            SPDLOG_DEBUG("Processed request ({:d} remaining)", backlog.size());
-        }
-
-        llama_free(context);
-    }
-
-    size_t worker_t::generate(
-            llama_context *context,
-            const generation_context_t &generation_context,
-            const std::optional<llama_decode_callback> &callback) const {
+    std::expected<size_t, backend_error_t>
+    worker_t::generate(const generation_context_t &generation_context,
+                       const std::optional<llama_decode_callback> &callback) const {
         // Store information about context and generation size
+        const auto callback_ = callback.value_or(llama_void_callback);
         auto max_new_tokens = generation_context.generation_params.max_new_tokens;

         // Convert sampling params to what llama.cpp is looking for
-        auto sampler = generation_context.sampling_params.into_llama_sampler(mModel_.get());
+        auto sampler = generation_context.sampling_params.into_llama_sampler(model_.get());

         // Set up the prompt
         auto copy = std::vector(generation_context.input_tokens.begin(), generation_context.input_tokens.end());
@@ -94,11 +79,10 @@ namespace huggingface::tgi::backends::llamacpp {
         // Decode
         auto n_decoded_tokens = 0;
         for (bool generating = true; generating; ++n_decoded_tokens) {
-            const auto callback_ = callback.value_or(llama_void_callback);
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
             const auto start = std::chrono::steady_clock::now();
-            const auto status = llama_decode(context, batch);
+            const auto status = llama_decode(context_.get(), batch);
             const auto end = std::chrono::steady_clock::now();
             const auto latency = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
             SPDLOG_DEBUG(FMT_STRING("Successfully decoded {:d} token(s) in {}"), batch.n_tokens, latency);
@@ -108,8 +92,8 @@ namespace huggingface::tgi::backends::llamacpp {
             batch.n_tokens = 0;
             if (LLAMA_SUCCESS(status)) [[likely]] {
                 // Sample the new token
-                auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
-                auto is_eog = llama_token_is_eog(mModel_.get(), new_token_id);
+                auto new_token_id = llama_sampler_sample(sampler.get(), context_.get(), -1);
+                auto is_eog = llama_token_is_eog(model_.get(), new_token_id);
                 auto new_token_logits = 0.0f; // TODO: return logit

                 // Handle termination cases
@@ -119,11 +103,8 @@ namespace huggingface::tgi::backends::llamacpp {
                 generating = !(has_reach_max_tokens | has_reach_eog);

                 // Bubble up the generated token if a callback is provided
-                const auto should_stop = std::invoke(std::forward<const llama_decode_callback>(callback_),
-                                                     new_token_id,
-                                                     new_token_logits,
-                                                     !generating,
-                                                     n_decoded_tokens + 1);
+                const auto should_stop =
+                        std::invoke(callback_, new_token_id, new_token_logits, !generating, n_decoded_tokens + 1);
                 generating ^= should_stop;

                 batch = llama_batch_get_one(&new_token_id, 1);
@@ -132,62 +113,4 @@ namespace huggingface::tgi::backends::llamacpp {
         return n_decoded_tokens;
     }
-
-    backend_base_t::backend_base_t(llama_model *model) : mModel_(model, llama_free_model) { llama_backend_init(); }
-
-    backend_base_t::~backend_base_t() { llama_backend_free(); }
-
-    std::expected<std::vector<llama_token>, backend_error_t> backend_base_t::generate(
-            std::span<const llama_token> tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            const std::optional<llama_decode_callback> &callback
-    ) {
-        // TODO: Should we provide a way to change this value?
-        auto generated = std::vector<llama_token>(2 << 8);
-
-        auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
-                                  size_t num_generated_tokens) -> bool {
-            generated.emplace_back(new_token_id);
-
-            if (callback.has_value())
-                return (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
-            return true;
-        };
-
-        auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);
-        return generated;
-    }
-
-    /** Single worker_t Backend impl **/
-
-    single_worker_backend_t::single_worker_backend_t(llama_model *model,
-                                                     const std::optional<llama_context_params> &params)
-            : backend_base_t(model),
-              mContext_(llama_context_factory(model)),
-              mWorker_(mModel_, params.value_or(llama_context_default_params())) {
-        llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
-    };
-
-    std::expected<size_t, backend_error_t>
-    single_worker_backend_t::stream(
-            std::span<const llama_token> tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            const llama_decode_callback &callback
-    ) {
-        return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens}, callback);
-    }
-
-    std::expected<size_t, backend_error_t>
-    multi_worker_backend_t::stream(
-            std::span<const llama_token> tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            const llama_decode_callback &callback
-    ) {
-        SPDLOG_WARN("Not implemented for multi_worker_t");
-        return 0;
-    }
 }
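With loop() and the backend_base_t hierarchy removed, the remaining C++ surface is a worker_t that owns its model and context and exposes a single generate(). A minimal sketch of how a caller could drive the new signature follows; the type and field names come from this diff, but the shared_ptr deleter, the callback body and the run_once wrapper are illustrative assumptions, not code from the commit.

// Illustrative use of the new worker_t::generate() API; not part of the commit.
#include <memory>
#include <vector>
#include <llama.h>
#include "backend.hpp"

using namespace huggingface::tgi::backends::llamacpp;

size_t run_once(llama_model *raw_model, const std::vector<llama_token> &prompt) {
    // worker_t builds and owns its llama_context from construction onwards.
    auto model = std::shared_ptr<llama_model>(raw_model, llama_free_model);
    const auto worker = worker_t(model, llama_context_default_params());

    // Generated tokens are observed through the callback; returning true stops generation early.
    std::vector<llama_token> generated;
    const llama_decode_callback on_token =
            [&generated](uint32_t new_token_id, float logit, bool is_eos, size_t step) -> bool {
        generated.push_back(static_cast<llama_token>(new_token_id));
        return false;
    };

    const auto result = worker.generate({{.max_new_tokens = 32}, {.top_k = 40}, prompt}, on_token);
    return result.value_or(0);
}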

View File

@@ -76,8 +76,8 @@ namespace huggingface::tgi::backends::llamacpp {
      */
     class worker_t {
     private:
-        const std::shared_ptr<llama_model> mModel_;
-        const llama_context_params mParams_;
+        std::shared_ptr<llama_model> model_;
+        llama_context_ptr context_;

     public:
         /**
@@ -85,7 +85,7 @@ namespace huggingface::tgi::backends::llamacpp {
          * @param model
          * @param params
          */
-        worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params);
+        worker_t(std::shared_ptr<llama_model>, const llama_context_params &);

         /**
          *
@@ -93,108 +93,8 @@ namespace huggingface::tgi::backends::llamacpp {
          * @param generation_context
          * @param callback
          */
-        size_t
-        generate(llama_context *, const generation_context_t &, const std::optional<llama_decode_callback> &) const;
+        [[nodiscard]] std::expected<size_t, backend_error_t>
+        generate(const generation_context_t &, const std::optional<llama_decode_callback> &) const;
-
-        /**
-         *
-         */
-        void loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const;
-    };
-
-    class backend_base_t {
-    protected:
-        std::shared_ptr<llama_model> mModel_;
-
-    public:
-        /**
-         *
-         * @param model
-         */
-        explicit backend_base_t(llama_model *model);
-
-        /**
-         * Destructor
-         */
-        ~backend_base_t();
-
-        /**
-         *
-         * @param tokens
-         * @param generation_params
-         * @param sampling_params
-         * @param callback
-         * @return
-         */
-        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-        std::expected<std::vector<llama_token>, backend_error_t> generate(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const std::optional<llama_decode_callback> &callback = std::nullopt
-        );
-
-        /**
-         *
-         * @param tokens
-         * @param generation_params
-         * @param sampling_params
-         * @params callback
-         * @return
-         */
-        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-        virtual std::expected<size_t, backend_error_t> stream(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const llama_decode_callback &callback
-        ) = 0;
-    };
-
-    class single_worker_backend_t : backend_base_t {
-    private:
-        constexpr static auto llama_context_factory = [](llama_model *pModel) -> llama_context_ptr {
-            auto llParams = llama_context_default_params();
-            llParams.flash_attn = true;
-            llParams.n_batch = 1;
-            llParams.n_threads = 1;
-            llParams.no_perf = true;
-            llParams.attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL;
-            return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
-        };
-
-        llama_context_ptr mContext_;
-        worker_t mWorker_;
-
-    public:
-        explicit single_worker_backend_t(llama_model *pModel, const std::optional<llama_context_params> &);
-
-        using backend_base_t::generate;
-
-        std::expected<size_t, backend_error_t> stream(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const llama_decode_callback &callback) override;
-    };
-
-    class multi_worker_backend_t : backend_base_t {
-    private:
-        llama_context_ptr mContext_;
-
-    public:
-        using backend_base_t::generate;
-
-        std::expected<size_t, backend_error_t> stream(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const llama_decode_callback &callback) override;
     };
 }
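worker_t now holds the context through llama_context_ptr rather than a raw llama_context* created and freed inside loop(). The alias itself is defined elsewhere in backend.hpp and is not part of this hunk; a plausible sketch of that pattern, assuming the unique_ptr-with-custom-deleter idiom the deleted code already relied on for llama_context_deleter and llama_sampler_ptr:

// Sketch only: the smart-pointer aliases implied by this header; the real
// definitions live elsewhere in backend.hpp and may differ in detail.
#include <memory>
#include <llama.h>

namespace huggingface::tgi::backends::llamacpp {
    // Deleters wrapping the llama.cpp C API so ownership is released automatically.
    static auto llama_context_deleter = [](llama_context *ctx) { llama_free(ctx); };
    static auto llama_sampler_deleter = [](llama_sampler *sampler) { llama_sampler_free(sampler); };

    using llama_context_ptr = std::unique_ptr<llama_context, decltype(llama_context_deleter)>;
    using llama_sampler_ptr = std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)>;
}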

View File

@@ -7,58 +7,41 @@
 #include <exception>
 #include <filesystem>
+#include <memory>
 #include <string_view>
 #include <variant>

 #include <spdlog/spdlog.h>

-#include "backend.hpp"
-
 namespace huggingface::tgi::backends::llamacpp {
-    struct generation_params_t;
-    struct sampling_params_t;
-    class llama_cpp_backend_impl_t;
+    class llama_cpp_worker_frontend_t;
 }

+#include "backend.hpp"
 #include "backends/llamacpp/src/lib.rs.h"
 #include "rust/cxx.h"

 namespace huggingface::tgi::backends::llamacpp {
-    // Concept identifying types which have a .generate() -> size_t method to do in-place generation
-    template<typename T>
-    concept has_stream_method = requires(
-            T t,
-            std::span<const llama_token> input_tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            llama_decode_callback callback
-    ) {
-        {
-            t.stream(input_tokens, generation_params, sampling_params, callback)
-        } -> std::same_as<std::expected<size_t, backend_error_t>>;
-    };
-
-    static_assert(has_stream_method<single_worker_backend_t>, "single_worker_backend_t doesn't meet concept has_stream_method");
-    static_assert(has_stream_method<multi_worker_backend_t>, "multi_worker_backend_t doesn't meet concept has_stream_method");
-
-    class llama_cpp_backend_exception_t : std::exception {
-    };
+    auto llama_model_deleter = [](llama_model *model) { llama_free_model(model); };
+    auto make_shared_llama_model = [](llama_model *model) {
+        return std::shared_ptr<llama_model>(model, llama_model_deleter);
+    };
+
+    class llama_cpp_backend_exception_t : std::exception {};

     /**
-     * Llama.cpp backend interfacing with Rust FFI layer
+     * Llama.cpp frontend over the worker interfacing with Rust FFI layer
      */
-    class llama_cpp_backend_impl_t {
+    class llama_cpp_worker_frontend_t {
     private:
-        std::variant<single_worker_backend_t, multi_worker_backend_t> mInner_;
+        std::shared_ptr<llama_model> model_;
+        worker_t worker_;

     public:
-        explicit llama_cpp_backend_impl_t(single_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
-
-        explicit llama_cpp_backend_impl_t(multi_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
+        explicit llama_cpp_worker_frontend_t(llama_model *model):
+            model_{ make_shared_llama_model(model) }, worker_(model_, {.no_perf = true}) {}

         size_t stream(
                 rust::Slice<const uint32_t> input_tokens,
@@ -67,41 +50,31 @@ namespace huggingface::tgi::backends::llamacpp {
                 InferContext *ctx,
                 rust::Fn<bool(InferContext *, uint32_t, float_t, bool, size_t)> callback
         ) {
-            // Define the visitor lambda function which requires the has_emplace_generate constraint on T
-            auto inner_fw = [=, &sampling_params, &ctx, &callback]<has_stream_method T>(T &&backend)
-                    -> std::expected<size_t, backend_error_t> {
-                auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
-                    return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
-                };
-
-                // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
-                auto input_tokens_v =
-                        std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
-
-                return backend.stream(
-                        input_tokens_v,
-                        generation_params,
-                        sampling_params,
-                        context_forwarding_callback
-                );
-            };
-
-            if (const auto result = std::visit(inner_fw, mInner_); result.has_value()) {
+            auto context_forwarding_callback =
+                    [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
+                return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
+            };
+
+            // Ask the compiler to create view over Rust slice transmuting from uint32_t* to llama_token*
+            auto input_tokens_v =
+                    std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
+
+            const auto generation_context = generation_context_t {generation_params, sampling_params, input_tokens_v};
+            if(const auto result = worker_.generate(generation_context, context_forwarding_callback); result.has_value()) [[likely]] {
                 return *result;
             } else {
-                throw llama_cpp_backend_exception_t();
+                throw llama_cpp_backend_exception_t {};
             }
         }
     };

-    std::unique_ptr<llama_cpp_backend_impl_t> create_single_worker_backend(rust::Str modelPath) {
+    std::unique_ptr<llama_cpp_worker_frontend_t> create_worker_frontend(rust::Str modelPath) {
         const auto cxxPath = std::string(modelPath);
         auto params = llama_model_default_params();
         params.use_mmap = true;

-        auto *model = llama_load_model_from_file(cxxPath.c_str(), params);
-        return std::make_unique<llama_cpp_backend_impl_t>(single_worker_backend_t { model, std::nullopt });
+        auto *model = (llama_load_model_from_file(cxxPath.c_str(), params));
+        return std::make_unique<llama_cpp_worker_frontend_t>(model);
     }
 }
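The reworked stream() shows the bridging done at the FFI boundary: the opaque InferContext pointer and the rust::Fn passed from Rust are captured into a plain C++ lambda, so worker_t::generate() only ever sees an ordinary callback. The same pattern, stripped of the cxx types and using a raw function pointer plus void* purely for illustration (none of these names exist in the backend):

// Generic illustration of the context-forwarding callback built inside stream();
// raw_token_cb, token_cb and forward_context are hypothetical names.
#include <cstddef>
#include <cstdint>
#include <functional>

// A C-style callback as it arrives across an FFI boundary: opaque context + function pointer.
using raw_token_cb = bool (*)(void *ctx, uint32_t token, float logit, bool is_eos, size_t n_generated);

// What the C++ side actually wants to work with.
using token_cb = std::function<bool(uint32_t, float, bool, size_t)>;

token_cb forward_context(void *opaque_ctx, raw_token_cb raw) {
    // Capture the opaque pointer by value; the caller keeps it alive for the whole
    // generation, exactly as stream() does with the InferContext*.
    return [opaque_ctx, raw](uint32_t token, float logit, bool is_eos, size_t n) -> bool {
        return raw(opaque_ctx, token, logit, is_eos, n);
    };
}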

View File

@@ -1,16 +1,17 @@
 //
 // Created by mfuntowicz on 10/3/24.
 //
-#include <fmt/color.h>
-#include <fmt/format.h>
-#include <fmt/std.h>
-#include <fmt/ranges.h>
+#include <memory>
+
+#include <llama.h>
 #include <spdlog/spdlog.h>
+#include <spdlog/fmt/ranges.h>
 #include "../csrc/backend.hpp"

 using namespace huggingface::tgi::backends::llamacpp;

+const auto llama_model_deleter = [](llama_model *model) { llama_free_model(model); };
+
 int main(int argc, char **argv) {
     if (argc < 2) {
         fmt::print("No model folder provider");
@@ -21,18 +22,28 @@ int main(int argc, char **argv) {
     const auto modelPath = absolute(std::filesystem::path(argv[1]));
     const auto params = llama_model_default_params();
-    auto *model = llama_load_model_from_file(modelPath.c_str(), params);
+    auto model = std::unique_ptr<llama_model, decltype(llama_model_deleter)>(
+            llama_load_model_from_file(modelPath.c_str(), params)
+    );

-    auto backend = single_worker_backend_t(model, {});
+    auto prompt = "My name is Morgan";
+    auto tokens = std::vector<llama_token>(16);
+    const auto nb_tokens = llama_tokenize(model.get(), prompt, sizeof(prompt), tokens.data(), tokens.size(), true,
+                                          false);
+    tokens.resize(nb_tokens);
+
+    auto backend = worker_t{std::move(model), {.n_batch = 1, .n_threads = 4}};
+
+    fmt::println("Tokenized: {}", tokens);

     // generate
-    const auto promptTokens = {128000, 5159, 836, 374, 23809, 11};
-    const auto out = backend.generate(promptTokens, {.max_new_tokens = 32}, {.top_k = 40});
-
-    if (out.has_value())
-        fmt::print(FMT_STRING("Generated: {}"), *out);
-    else {
-        const auto err = out.error();
-        fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
-    }
+    auto generated_tokens = std::vector<llama_token>(32);
+    const auto n_generated_tokens = backend.generate(
+            {{.max_new_tokens = 32}, {.top_k = 40}, tokens},
+            [&generated_tokens](llama_token new_token_id, float_t logit, bool is_eos, size_t step) -> bool {
+                generated_tokens.emplace(generated_tokens.begin() + (step - 1), new_token_id);
+                return false;
+            }
+    );
+    generated_tokens.resize(n_generated_tokens.value());
+    fmt::println("Generated {} tokens", generated_tokens);
 }
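The refactored example unwraps the std::expected result with value(). A caller that wants to report backend_error_t instead of asserting success would branch on the result first; a minimal sketch under that assumption (the helper name and error rendering are illustrative, not part of the commit):

// Sketch only: branch on the std::expected result of worker_t::generate()
// instead of calling value() unconditionally; generate_and_report is hypothetical.
#include <cstdint>
#include <optional>
#include <vector>
#include <fmt/format.h>
#include <llama.h>
#include "../csrc/backend.hpp"

using namespace huggingface::tgi::backends::llamacpp;

void generate_and_report(const worker_t &backend, const std::vector<llama_token> &tokens) {
    // Without a callback, generate() falls back to llama_void_callback and only reports the count.
    const auto result = backend.generate({{.max_new_tokens = 32}, {.top_k = 40}, tokens}, std::nullopt);
    if (result.has_value()) {
        fmt::println("Generated {:d} token(s)", *result);
    } else {
        // The pre-refactor example printed backend_error_t through an integral cast; do the same here.
        fmt::println("Generation failed with error code {:d}", static_cast<int>(result.error()));
    }
}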

View File

@@ -1,8 +1,9 @@
 use crate::ffi::{
-    create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
+    create_worker_frontend, GenerationParams, LlamaCppWorkerFrontend, SamplingParams,
 };
 use async_trait::async_trait;
 use cxx::UniquePtr;
+use std::ops::Deref;
 use std::path::{Path, PathBuf};
 use std::sync::mpsc::{channel, Receiver, Sender};
 use std::sync::Arc;
@@ -21,7 +22,7 @@ use tracing::{debug, error, info};
 type InferResult = Result<InferStreamResponse, InferError>;

-unsafe impl Send for LlamaCppBackendImpl {}
+unsafe impl Send for LlamaCppWorkerFrontend {}

 impl From<&ValidParameters> for SamplingParams {
     fn from(v: &ValidParameters) -> Self {
@@ -68,41 +69,54 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }

-pub struct LlamaCppBackend {
-    backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
-    _scheduler_handle: JoinHandle<()>,
+// pub struct LlamaCppBackend {
+//     backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
+//     _scheduler_handle: JoinHandle<()>,
+// }
+struct LlamaCppWorker {
+    sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
+    handle: JoinHandle<()>,
+}
+
+pub enum LlamaCppBackend {
+    Single(LlamaCppWorker),
+    // Multi(Vec<LlamaCppWorker>)
 }

 impl LlamaCppBackend {
-    pub fn new<P: AsRef<Path> + Send>(
+    fn allocate_worker(
+        path: &Path,
+    ) -> Result<UniquePtr<LlamaCppWorkerFrontend>, LlamaCppBackendError> {
+        create_worker_frontend(&path.display().to_string()).map_err(|ref err| {
+            LlamaCppBackendError::ModelInitializationFailed(path.to_path_buf(), err.to_string())
+        })
+    }
+
+    pub fn new<P: AsRef<Path>>(
         model_path: P,
         tokenizer: Tokenizer,
+        num_cores_per_instance: u16,
     ) -> Result<Self, LlamaCppBackendError> {
-        let path = Arc::new(model_path.as_ref());
+        let shared_path = Arc::new(model_path);
+        let path = shared_path.deref().as_ref();
         if !path.exists() {
             return Err(LlamaCppBackendError::ModelFileDoesntExist(
                 path.display().to_string(),
             ));
         }

-        let backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
-            LlamaCppBackendError::ModelInitializationFailed(
-                path.to_path_buf(),
-                err.what().to_string(),
-            )
-        })?;
-
-        info!(
-            "Successfully initialized llama.cpp backend from {}",
-            path.display()
-        );
-
-        let (submitter, receiver) = channel();
-        let handle = unsafe { spawn(|| scheduler_loop(backend, tokenizer, receiver)) };
-        Ok(Self {
-            backlog: submitter,
-            _scheduler_handle: handle,
-        })
+        let worker = match num_cores_per_instance {
+            0 => {
+                let worker = Self::allocate_worker(path)?;
+                let (sender, receiver) = channel();
+                let handle = spawn(|| scheduler_loop(worker, tokenizer, receiver));
+                LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
+            }
+            _ => panic!("No supported yet"),
+        };
+
+        Ok(worker)
     }
 }
@@ -169,18 +183,16 @@ fn llama_generate_callback(
     };

     // Send back to the client
-    let should_stop = if let Err(ref _err) = ctx.stream.send(response) {
+    if let Err(ref _err) = ctx.stream.send(response) {
         error!("Failed to send back the response to the client, cancelling request");
         true
     } else {
-        true
-    };
-
-    should_stop
+        false
+    }
 }

-unsafe fn scheduler_loop(
-    mut backend: UniquePtr<LlamaCppBackendImpl>,
+fn scheduler_loop(
+    mut backend: UniquePtr<LlamaCppWorkerFrontend>,
     tokenizer: Tokenizer,
     backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
@@ -204,20 +216,23 @@ unsafe fn scheduler_loop(
                     generation,
                 });

-                let boxed_ctx = Box::into_raw(ctx);
-
-                if let Err(e) = backend.pin_mut().stream(
-                    &input_tokens,
-                    generation_params,
-                    &sampling_params,
-                    boxed_ctx,
-                    llama_generate_callback,
-                ) {
-                    error!("Error while decoding tokens... {}", e.what());
-                }
-
-                // Make sure we re-keep track of the OpaqueStream box
-                let _ = Box::from_raw(boxed_ctx);
+                // We leak the box to avoid it being freed after the first callback call
+                // when going out of scope
+                unsafe {
+                    let boxed_ctx = Box::into_raw(ctx);
+                    if let Err(e) = backend.pin_mut().stream(
+                        &input_tokens,
+                        generation_params,
+                        &sampling_params,
+                        boxed_ctx,
+                        llama_generate_callback,
+                    ) {
+                        error!("Error while decoding tokens... {}", e.what());
+                    }
+
+                    // Make sure we re-keep track of the OpaqueStream box
+                    let _ = Box::from_raw(boxed_ctx);
+                }
             }
         } else {
             info!("IPC channel is closed, exiting the scheduler loop");
@@ -244,11 +259,13 @@ impl Backend for LlamaCppBackend {
                 sampling_params,
             };

-            match self.backlog.send((ctx, sx)) {
-                Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
-                Err(_) => Err(InferError::GenerationError(
-                    "Failed to sent the request".to_string(),
-                )),
+            match self {
+                LlamaCppBackend::Single(worker) => match worker.sender.send((ctx, sx)) {
+                    Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
+                    Err(_) => Err(InferError::GenerationError(
+                        "Failed to sent the request".to_string(),
+                    )),
+                },
             }
         } else {
             Err(InferError::GenerationError(

View File

@@ -46,14 +46,13 @@ mod ffi {
         type SamplingParams;

         /// Represent an instance of the llama.cpp backend instance on C++ side
-        #[cxx_name = "llama_cpp_backend_impl_t"]
-        type LlamaCppBackendImpl;
+        #[cxx_name = "llama_cpp_worker_frontend_t"]
+        type LlamaCppWorkerFrontend;

-        #[rust_name = "create_single_worker_backend"]
-        fn create_single_worker_backend(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;
+        fn create_worker_frontend(modelPath: &str) -> Result<UniquePtr<LlamaCppWorkerFrontend>>;

         unsafe fn stream(
-            self: Pin<&mut LlamaCppBackendImpl>,
+            self: Pin<&mut LlamaCppWorkerFrontend>,
             tokens: &[u32],
             generation_params: GenerationParams,
             sampling_params: &SamplingParams,

View File

@@ -37,8 +37,8 @@ struct Args {
     port: u16,
     #[clap(long, env, help = "Path to GGUF model file(s) to load")]
     gguf_path: PathBuf,
-    // #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
-    // num_model_instance: u16,
+    #[clap(long, env, help = "Number of CPU core per instance(s)")]
+    num_cores_per_instance: Option<u16>,
     #[clap(long, env, required = true)]
     tokenizer_name: String,
     #[clap(long, env)]
@@ -95,7 +95,7 @@ async fn main() -> Result<(), RouterError> {
         hostname,
         port,
         gguf_path,
-        // num_model_instance,
+        num_cores_per_instance,
         tokenizer_name,
         tokenizer_config_path,
         revision,
@@ -164,7 +164,7 @@
     };
     let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
         .expect("Failed to retrieve tokenizer");
-    let backend = LlamaCppBackend::new(gguf_path, tokenizer)?;
+    let backend = LlamaCppBackend::new(gguf_path, tokenizer, num_cores_per_instance.unwrap_or(0))?;

     // Run server
     server::run(