//
// Created by mfuntowicz on 10/23/24.
//
#ifndef TGI_LLAMA_CPP_BACKEND_FFI_HPP
#define TGI_LLAMA_CPP_BACKEND_FFI_HPP

#include <exception>
#include <expected>
#include <memory>
#include <variant>

#include "backend.hpp"

namespace huggingface::tgi::backends::llamacpp {
    struct generation_params_t;
    struct sampling_params_t;

    class llama_cpp_backend_impl_t;
}

#include "backends/llamacpp/src/lib.rs.h"
#include "rust/cxx.h"

namespace huggingface::tgi::backends::llamacpp {

    // Concept identifying types which have a .generate() -> size_t method to do in-place generation
    template<typename T>
    concept has_emplace_generate = requires(
            T t,
            std::span<const llama_token> input_tokens,
            std::span<llama_token> generated_tokens,
            const generation_params_t &generation_params,
            const sampling_params_t &sampling_params,
            llama_decode_callback callback
    ) {
        {
        t.generate(input_tokens, generated_tokens, generation_params, sampling_params, callback)
        } -> std::same_as<std::expected<size_t, backend_error_t>>;
    };

    static_assert(has_emplace_generate<single_worker_backend_t>,
                  "single_worker_backend_t doesn't meet concept is_generate_emplace_capable");
    static_assert(has_emplace_generate<multi_worker_backend_t>,
                  "multi_worker_backend_t doesn't meet concept is_generate_emplace_capable");

    // Thrown towards the Rust FFI layer whenever the underlying backend fails to generate
    class llama_cpp_backend_exception_t : std::exception {};

    /**
     * Llama.cpp backend interfacing with Rust FFI layer
     */
    class llama_cpp_backend_impl_t {
    private:
        std::variant<single_worker_backend_t, multi_worker_backend_t> mInner_;

    public:
        explicit llama_cpp_backend_impl_t(single_worker_backend_t &&backend) : mInner_(std::move(backend)) {}

        explicit llama_cpp_backend_impl_t(multi_worker_backend_t &&backend) : mInner_(std::move(backend)) {}

        /**
         * Generate tokens in-place into `generated_tokens`, forwarding every decoded token to the
         * provided Rust callback through the opaque stream handle. Returns the number of generated tokens.
         */
        size_t stream(
                rust::Slice<const uint32_t> input_tokens,
                rust::Slice<uint32_t> generated_tokens,
                const generation_params_t generation_params,
                const sampling_params_t &sampling_params,
                OpaqueStream *stream,
                rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool, size_t)> callback
        ) {
            // Define the visitor lambda function which requires the has_emplace_generate constraint on T
            auto inner_fw = [=, &sampling_params, &stream, &callback]<has_emplace_generate T>(T &&backend)
                    -> std::expected<size_t, backend_error_t> {

                // Wrap the Rust callback so the backend-facing callback carries the stream handle along
                auto context_forwarding_callback =
                        [=, &stream](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) {
                            callback(stream, new_token_id, logits, is_eos, n_generated_tokens);
                        };

                // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
                auto input_tokens_v = std::span(
                        reinterpret_cast<const int32_t *>(input_tokens.data()), input_tokens.size());
                auto generated_tokens_v = std::span(
                        reinterpret_cast<int32_t *>(generated_tokens.data()), generated_tokens.size());

                return backend.generate(
                        input_tokens_v,
                        generated_tokens_v,
                        generation_params,
                        sampling_params,
                        context_forwarding_callback
                );
            };

            if (const auto result = std::visit(inner_fw, mInner_); result.has_value()) {
                return *result;
            } else {
                throw llama_cpp_backend_exception_t();
            }
        }
    };

    /**
     * Load the model located at `modelPath` (memory-mapped) and wrap it in a single-worker backend.
     */
    std::unique_ptr<llama_cpp_backend_impl_t> create_single_worker_backend(rust::Str modelPath) {
        const auto cxxPath = std::string(modelPath);
        auto params = llama_model_default_params();
        params.use_mmap = true;

        auto *model = llama_load_model_from_file(cxxPath.c_str(), params);
        auto backend = single_worker_backend_t(model, std::nullopt);
        return std::make_unique<llama_cpp_backend_impl_t>(std::move(backend));
    }
}

#endif //TGI_LLAMA_CPP_BACKEND_FFI_HPP
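
// Illustrative sketch (kept commented out, not part of the backend itself): a minimal type
// satisfying has_emplace_generate, e.g. to stub the FFI layer in a unit test. `echo_backend_t`
// is hypothetical; llama_token, generation_params_t, sampling_params_t, llama_decode_callback
// and backend_error_t are the types already referenced above (declared in backend.hpp and the
// Rust bridge).
//
//     struct echo_backend_t {
//         std::expected<size_t, backend_error_t> generate(
//                 std::span<const llama_token> input_tokens,
//                 std::span<llama_token> generated_tokens,
//                 const generation_params_t &,
//                 const sampling_params_t &,
//                 llama_decode_callback callback) {
//             // Copy the prompt back and report every token through the callback,
//             // flagging the last one as end-of-sequence.
//             const auto n = input_tokens.size() < generated_tokens.size()
//                     ? input_tokens.size() : generated_tokens.size();
//             for (size_t i = 0; i < n; ++i) {
//                 generated_tokens[i] = input_tokens[i];
//                 callback(input_tokens[i], 0.0f, i + 1 == n, i + 1);
//             }
//             return n;
//         }
//     };
//     static_assert(has_emplace_generate<echo_backend_t>);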