feat(llamacpp): enable cuda

Morgan Funtowicz 2024-10-21 09:14:51 +02:00
parent fa89d1e613
commit 05ad684676
2 changed files with 27 additions and 18 deletions

build.rs

@@ -2,6 +2,7 @@ use cxx_build::CFG;
 use std::env;
 use std::path::PathBuf;
 
+const CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS: &str = "75-real;80-real;86-real;89-real;90-real";
 const CMAKE_LLAMA_CPP_TARGET: &str = "tgi_llama_cpp_backend_impl";
 const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
 const MPI_REQUIRED_VERSION: &str = "4.1";
@@ -20,6 +21,10 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf
         .map(|val| PathBuf::from(val))
         .unwrap_or(out_dir.join("dist"));
 
+    let build_cuda = option_env!("LLAMA_CPP_BUILD_CUDA").unwrap_or("OFF");
+    let cuda_archs =
+        option_env!("LLAMA_CPP_TARGET_CUDA_ARCHS").unwrap_or(CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS);
+
     let _ = cmake::Config::new(".")
         .uses_cxx11()
         .generator("Ninja")
@@ -29,9 +34,8 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf
         })
         .env("OPT_LEVEL", opt_level)
         .define("CMAKE_INSTALL_PREFIX", &install_path)
-        // .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
-        // .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
-        // .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
+        .define("LLAMA_CPP_BUILD_CUDA", build_cuda)
+        .define("LLAMA_CPP_TARGET_CUDA_ARCHS", cuda_archs)
         .build();
 
     // Additional transitive CMake dependencies
@@ -61,7 +65,7 @@ fn build_ffi_layer(deps_folder: &PathBuf) {
         .include(deps_folder.join("llama-src").join("ggml").join("include"))
         .include(deps_folder.join("llama-src").join("include"))
         .file("csrc/backend.cpp")
-        .std("c++20")
+        .std("c++23")
         .compile(CMAKE_LLAMA_CPP_TARGET);
 
     println!("cargo:rerun-if-changed=CMakeLists.txt");

csrc/backend.cpp

@@ -10,14 +10,15 @@
 namespace huggingface::tgi::backends::llama {
-    std::unique_ptr<TgiLlamaCppBackend> CreateLlamaCppBackend(std::string_view root) {
+    std::unique_ptr<huggingface::tgi::backends::llama::TgiLlamaCppBackend>
+    CreateLlamaCppBackend(std::string_view root) {
         SPDLOG_INFO(FMT_STRING("Loading model from {}"), root);
         gpt_init();
 
         // Fake argv
         std::vector<std::string_view> args = {"tgi_llama_cpp_backend", "--model", root};
-        std::vector<char*> argv;
-        for(const auto& arg : args) {
+        std::vector<char *> argv;
+        for (const auto &arg: args) {
             argv.push_back(const_cast<char *>(arg.data()));
         }
         argv.push_back(nullptr);
@@ -39,35 +40,39 @@ namespace huggingface::tgi::backends::llama {
         auto loras = result.lora_adapters;
 
         // Make sure everything is correctly initialized
-        if(model == nullptr)
+        if (model == nullptr)
             throw std::runtime_error(fmt::format("Failed to load model from {}", root));
 
-        return std::make_unique<TgiLlamaCppBackend>(model, context);
+        return std::make_unique<huggingface::tgi::backends::llama::TgiLlamaCppBackend>(model, context);
     }
 
-    TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model, llama_context *const ctx)
-        : model(model), ctx(ctx), batch()
-    {
+    huggingface::tgi::backends::llama::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model,
+                                                                              llama_context *const ctx)
+            : model(model), ctx(ctx), batch() {
         char modelName[128];
         llama_model_meta_val_str(model, "general.name", modelName, sizeof(modelName));
         SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
     }
 
-    TgiLlamaCppBackend::~TgiLlamaCppBackend() {
-        if(model)
-        {
+    huggingface::tgi::backends::llama::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
+        if (model) {
             SPDLOG_DEBUG("Freeing llama.cpp model");
             llama_free_model(model);
         }
 
-        if(ctx)
-        {
+        if (ctx) {
             SPDLOG_DEBUG("Freeing llama.cpp context");
             llama_free(ctx);
         }
     }
 
-    void TgiLlamaCppBackend::schedule() {
+    void huggingface::tgi::backends::llama::TgiLlamaCppBackend::schedule() {
         std::vector<llama_token> tokens;
     }
+
+    namespace impl {
+        class LlamaCppBackendImpl {
+        };
+    }
 }
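
The empty `impl::LlamaCppBackendImpl` class added at the bottom is presumably the placeholder that the Rust FFI layer (compiled by `build_ffi_layer` in the build script, target `tgi_llama_cpp_backend_impl`) will bind against. A hypothetical sketch of the matching Rust side of such a cxx bridge, assuming a `csrc/backend.hpp` header that is not part of this commit:

// Hypothetical sketch of the Rust side of the FFI bridge; the module name and
// header path are assumptions, only the C++ namespace and class name come from
// the diff above.
#[cxx::bridge(namespace = "huggingface::tgi::backends::llama::impl")]
mod ffi {
    unsafe extern "C++" {
        // Assumed header location; the build script only compiles csrc/backend.cpp.
        include!("csrc/backend.hpp");

        /// Opaque handle to the C++ backend implementation.
        type LlamaCppBackendImpl;
    }
}

Keeping the C++ class empty for now lets the bridge declare an opaque type and compile while the actual generation methods are still being ported.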