feat(backend): wip Rust binding

Repository: https://github.com/huggingface/text-generation-inference.git
Commit: 355d8a55b4 (parent: f9c248657d)
backends/llamacpp/CMakeLists.txt
@@ -11,6 +11,13 @@ set(LLAMA_CPP_TARGET_CUDA_ARCHS "75-real;80-real;86-real;89-real;90-real" CACHE
 option(LLAMA_CPP_BUILD_OFFLINE_RUNNER "Flag to build the standalone c++ backend runner")
 option(LLAMA_CPP_BUILD_CUDA "Flag to build CUDA enabled inference through llama.cpp")
 
+if(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" AND ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
+    message(STATUS "Targeting libc++")
+    set(CMAKE_CXX_FLAGS -stdlib=libc++ ${CMAKE_CXX_FLAGS})
+else()
+    message(STATUS "Not using libc++ ${CMAKE_CXX_COMPILER_ID} ${CMAKE_SYSTEM_NAME}")
+endif()
+
 # Add dependencies
 include(cmake/fmt.cmake)
 include(cmake/spdlog.cmake)
backends/llamacpp/build.rs
@@ -59,18 +59,20 @@ fn build_ffi_layer(deps_folder: &PathBuf) {
     CFG.include_prefix = "backends/llamacpp";
     cxx_build::bridge("src/lib.rs")
         .static_flag(true)
+        .std("c++23")
         .include(deps_folder.join("fmt-src").join("include"))
         .include(deps_folder.join("spdlog-src").join("include"))
         .include(deps_folder.join("llama-src").join("common"))
         .include(deps_folder.join("llama-src").join("ggml").join("include"))
         .include(deps_folder.join("llama-src").join("include"))
-        .file("csrc/backend.cpp")
-        .std("c++23")
+        .include("csrc/backend.hpp")
+        .file("csrc/ffi.cpp")
         .compile(CMAKE_LLAMA_CPP_TARGET);
 
     println!("cargo:rerun-if-changed=CMakeLists.txt");
     println!("cargo:rerun-if-changed=csrc/backend.hpp");
     println!("cargo:rerun-if-changed=csrc/backend.cpp");
+    println!("cargo:rerun-if-changed=csrc/ffi.hpp");
 }
 
 fn main() {
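For context on what this build-script change does: cxx_build::bridge parses the #[cxx::bridge] module in src/lib.rs, emits the generated C++ for it (the backends/llamacpp/src/lib.rs.h header referenced from ffi.hpp below), and returns a cc::Build, so the hand-written glue in csrc/ffi.cpp is compiled into the same static library as the generated code. A minimal sketch of that pattern, with an illustrative dependency layout and target name rather than the exact values used above:

    // Sketch of the cxx_build pattern (illustrative names, not the actual TGI build script).
    use std::path::Path;

    fn compile_bridge(deps: &Path) {
        cxx_build::bridge("src/lib.rs")      // expand #[cxx::bridge]; generates lib.rs.h / lib.rs.cc
            .static_flag(true)               // archive everything into a static library
            .std("c++23")                    // keep the standard in sync with the CMake-built parts
            .include(deps.join("llama-src").join("include")) // hypothetical dependency checkout
            .file("csrc/ffi.cpp")            // hand-written C++ glue implementing the bridge
            .compile("example-llamacpp-bridge"); // emits the cargo link directives

        // Rebuild the FFI layer whenever the bridge or the glue changes.
        println!("cargo:rerun-if-changed=src/lib.rs");
        println!("cargo:rerun-if-changed=csrc/ffi.cpp");
    }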
backends/llamacpp/csrc/backend.hpp
@@ -8,6 +8,7 @@
 #include <expected>
 #include <filesystem>
 #include <memory>
+#include <span>
 #include <llama.h>
 
 #define LLAMA_SUCCESS(x) x == 0
backends/llamacpp/csrc/ffi.hpp (new file, 19 lines)
@@ -0,0 +1,19 @@
+//
+// Created by mfuntowicz on 10/23/24.
+//
+
+#ifndef TGI_LLAMA_CPP_BACKEND_FFI_HPP
+#define TGI_LLAMA_CPP_BACKEND_FFI_HPP
+
+#include "backend.hpp"
+//#include "backends/llamacpp/src/lib.rs.h"
+
+
+namespace huggingface::tgi::backends::llama {
+    class LlamaCppBackendImpl {
+
+    };
+}
+
+
+#endif //TGI_LLAMA_CPP_BACKEND_FFI_HPP
backends/llamacpp/src/backend.rs
@@ -1,8 +1,21 @@
+use crate::ffi::{create_llamacpp_backend, LlamaCppBackendImpl};
+use cxx::UniquePtr;
+use std::path::Path;
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 
-pub struct TgiLlamaCppBakend {}
+pub struct TgiLlamaCppBakend {
+    backend: UniquePtr<LlamaCppBackendImpl>,
+}
+
+impl TgiLlamaCppBakend {
+    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self, ()> {
+        Ok(Self {
+            backend: create_llamacpp_backend(model_path.as_ref().to_str().unwrap()),
+        })
+    }
+}
 
 impl Backend for TgiLlamaCppBakend {
     fn schedule(
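The Backend impl itself is still a stub at this point in the WIP; the hunk ends inside fn schedule. Judging from the imports above, schedule is expected to hand a stream of Result<InferStreamResponse, InferError> back to the router. A purely hypothetical sketch of that plumbing (not part of the commit): an unbounded tokio channel, with the receiving half wrapped as the stream the router consumes and the sending half left for whatever task ends up driving LlamaCppBackendImpl.

    // Hypothetical helper, not from the commit: the channel/stream pair a
    // schedule() implementation would typically return to the router.
    use text_generation_router::infer::{InferError, InferStreamResponse};
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;

    type InferResult = Result<InferStreamResponse, InferError>;

    fn response_channel() -> (
        mpsc::UnboundedSender<InferResult>,
        UnboundedReceiverStream<InferResult>,
    ) {
        // The sender would be moved into the generation task that calls into
        // the C++ backend; the stream is what the router polls for tokens.
        let (sender, receiver) = mpsc::unbounded_channel();
        (sender, UnboundedReceiverStream::new(receiver))
    }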
backends/llamacpp/src/lib.rs
@@ -1,11 +1,16 @@
 pub mod backend;
 
-#[cxx::bridge(namespace = "huggingface::tgi::backends::llama::impl")]
+#[cxx::bridge(namespace = "huggingface::tgi::backends::llama")]
 mod ffi {
     unsafe extern "C++" {
-        include!("backends/llamacpp/csrc/backend.cpp");
+        include!("backends/llamacpp/csrc/ffi.hpp");
 
         /// Represent an instance of the llama.cpp backend instance on C++ side
         type LlamaCppBackendImpl;
+
+        #[rust_name = "create_llamacpp_backend"]
+        fn CreateLlamaCppBackend(
+            engine_folder: &str,
+        ) -> UniquePtr<LlamaCppBackendImpl>;
     }
 }
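A note on how the bridge above is consumed: create_llamacpp_backend returns a cxx::UniquePtr that owns the C++ LlamaCppBackendImpl, and a UniquePtr can be null when the C++ factory fails to construct the object. A hypothetical caller-side sketch (the function name and error handling are illustrative, not from the commit, and it assumes the bridge items are marked pub or pub(crate) so they are visible outside mod ffi):

    // Hypothetical usage sketch, not part of the commit: reject a null UniquePtr
    // instead of storing it blindly.
    fn init_backend(engine_folder: &str) -> Result<cxx::UniquePtr<ffi::LlamaCppBackendImpl>, String> {
        let backend = ffi::create_llamacpp_backend(engine_folder);
        if backend.is_null() {
            Err(format!(
                "failed to initialize the llama.cpp backend from {engine_folder}"
            ))
        } else {
            Ok(backend)
        }
    }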