From 355d8a55b46f4ac56a8741bc3e3960a6bed2c03a Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Thu, 24 Oct 2024 09:56:40 +0200
Subject: [PATCH] feat(backend): wip Rust binding

---
 backends/llamacpp/CMakeLists.txt   |  7 +++++++
 backends/llamacpp/build.rs         |  6 ++++--
 backends/llamacpp/csrc/backend.hpp |  1 +
 backends/llamacpp/csrc/ffi.hpp     | 19 +++++++++++++++++++
 backends/llamacpp/src/backend.rs   | 15 ++++++++++++++-
 backends/llamacpp/src/lib.rs       |  9 +++++++--
 6 files changed, 52 insertions(+), 5 deletions(-)
 create mode 100644 backends/llamacpp/csrc/ffi.hpp

diff --git a/backends/llamacpp/CMakeLists.txt b/backends/llamacpp/CMakeLists.txt
index 9f08d0f3..644db5ae 100644
--- a/backends/llamacpp/CMakeLists.txt
+++ b/backends/llamacpp/CMakeLists.txt
@@ -11,6 +11,13 @@ set(LLAMA_CPP_TARGET_CUDA_ARCHS "75-real;80-real;86-real;89-real;90-real" CACHE
 option(LLAMA_CPP_BUILD_OFFLINE_RUNNER "Flag to build the standalone c++ backend runner")
 option(LLAMA_CPP_BUILD_CUDA "Flag to build CUDA enabled inference through llama.cpp")
 
+if(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" AND ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
+    message(STATUS "Targeting libc++")
+    set(CMAKE_CXX_FLAGS -stdlib=libc++ ${CMAKE_CXX_FLAGS})
+else()
+    message(STATUS "Not using libc++ ${CMAKE_CXX_COMPILER_ID} ${CMAKE_SYSTEM_NAME}")
+endif()
+
 # Add dependencies
 include(cmake/fmt.cmake)
 include(cmake/spdlog.cmake)
diff --git a/backends/llamacpp/build.rs b/backends/llamacpp/build.rs
index 26ea8d92..d84e517f 100644
--- a/backends/llamacpp/build.rs
+++ b/backends/llamacpp/build.rs
@@ -59,18 +59,20 @@ fn build_ffi_layer(deps_folder: &PathBuf) {
     CFG.include_prefix = "backends/llamacpp";
     cxx_build::bridge("src/lib.rs")
         .static_flag(true)
+        .std("c++23")
         .include(deps_folder.join("fmt-src").join("include"))
         .include(deps_folder.join("spdlog-src").join("include"))
         .include(deps_folder.join("llama-src").join("common"))
         .include(deps_folder.join("llama-src").join("ggml").join("include"))
         .include(deps_folder.join("llama-src").join("include"))
-        .file("csrc/backend.cpp")
-        .std("c++23")
+        .include("csrc/backend.hpp")
+        .file("csrc/ffi.cpp")
         .compile(CMAKE_LLAMA_CPP_TARGET);
 
     println!("cargo:rerun-if-changed=CMakeLists.txt");
     println!("cargo:rerun-if-changed=csrc/backend.hpp");
     println!("cargo:rerun-if-changed=csrc/backend.cpp");
+    println!("cargo:rerun-if-changed=csrc/ffi.hpp");
 }
 
 fn main() {
diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp
index e4c31ad6..7075642a 100644
--- a/backends/llamacpp/csrc/backend.hpp
+++ b/backends/llamacpp/csrc/backend.hpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 #include
 
 #define LLAMA_SUCCESS(x) x == 0
diff --git a/backends/llamacpp/csrc/ffi.hpp b/backends/llamacpp/csrc/ffi.hpp
new file mode 100644
index 00000000..e924316e
--- /dev/null
+++ b/backends/llamacpp/csrc/ffi.hpp
@@ -0,0 +1,19 @@
+//
+// Created by mfuntowicz on 10/23/24.
+//
+
+#ifndef TGI_LLAMA_CPP_BACKEND_FFI_HPP
+#define TGI_LLAMA_CPP_BACKEND_FFI_HPP
+
+#include "backend.hpp"
+//#include "backends/llamacpp/src/lib.rs.h"
+
+
+namespace huggingface::tgi::backends::llama {
+    class LlamaCppBackendImpl {
+
+    };
+}
+
+
+#endif //TGI_LLAMA_CPP_BACKEND_FFI_HPP
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 8af1067b..89daeee3 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -1,8 +1,21 @@
+use crate::ffi::{create_llamacpp_backend, LlamaCppBackendImpl};
+use cxx::UniquePtr;
+use std::path::Path;
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 
-pub struct TgiLlamaCppBakend {}
+pub struct TgiLlamaCppBakend {
+    backend: UniquePtr<LlamaCppBackendImpl>,
+}
+
+impl TgiLlamaCppBakend {
+    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self, InferError> {
+        Ok(Self {
+            backend: create_llamacpp_backend(model_path.as_ref().to_str().unwrap()),
+        })
+    }
+}
 
 impl Backend for TgiLlamaCppBakend {
     fn schedule(
diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs
index bea7c06f..d25e3ca0 100644
--- a/backends/llamacpp/src/lib.rs
+++ b/backends/llamacpp/src/lib.rs
@@ -1,11 +1,16 @@
 pub mod backend;
 
-#[cxx::bridge(namespace = "huggingface::tgi::backends::llama::impl")]
+#[cxx::bridge(namespace = "huggingface::tgi::backends::llama")]
 mod ffi {
     unsafe extern "C++" {
-        include!("backends/llamacpp/csrc/backend.cpp");
+        include!("backends/llamacpp/csrc/ffi.hpp");
 
         /// Represent an instance of the llama.cpp backend instance on C++ side
         type LlamaCppBackendImpl;
+
+        #[rust_name = "create_llamacpp_backend"]
+        fn CreateLlamaCppBackend(
+            engine_folder: &str,
+        ) -> UniquePtr<LlamaCppBackendImpl>;
     }
 }
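
Reviewer note: the bridge declares CreateLlamaCppBackend and build.rs now compiles csrc/ffi.cpp, but this commit adds neither that file nor a definition of the function, so the crate will not link yet (the commit is marked wip). Below is a minimal sketch of what csrc/ffi.cpp could look like; the gguf_path name, the eager std::string copy, and the empty backend construction are illustrative assumptions, not code from this patch.

// csrc/ffi.cpp -- hypothetical sketch only; this file is not part of the patch.
#include <memory>
#include <string>

#include "ffi.hpp"
#include "rust/cxx.h"  // rust::Str, generated alongside the cxx bridge

namespace huggingface::tgi::backends::llama {
    // cxx lowers the bridge declaration
    //   fn CreateLlamaCppBackend(engine_folder: &str) -> UniquePtr<LlamaCppBackendImpl>
    // to a free function in this namespace taking rust::Str and returning std::unique_ptr.
    std::unique_ptr<LlamaCppBackendImpl> CreateLlamaCppBackend(rust::Str engine_folder) {
        // rust::Str is not NUL-terminated; copy it into an owned std::string first.
        const auto gguf_path = std::string(engine_folder);

        // TODO(wip): construct the real backend from backend.hpp with the model at gguf_path.
        return std::make_unique<LlamaCppBackendImpl>();
    }
}

Returning UniquePtr from C++ is the usual cxx pattern for opaque types: Rust owns the LlamaCppBackendImpl and drops it through the C++ destructor, which matches the backend: UniquePtr<LlamaCppBackendImpl> field added to TgiLlamaCppBakend in backend.rs.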