mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-05-08 19:02:07 +00:00
feat(llamacpp): enable cuda
parent fa89d1e613
commit 05ad684676
@@ -2,6 +2,7 @@ use cxx_build::CFG;
 use std::env;
 use std::path::PathBuf;
 
+const CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS: &str = "75-real;80-real;86-real;89-real;90-real";
 const CMAKE_LLAMA_CPP_TARGET: &str = "tgi_llama_cpp_backend_impl";
 const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
 const MPI_REQUIRED_VERSION: &str = "4.1";
@@ -20,6 +21,10 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf
         .map(|val| PathBuf::from(val))
         .unwrap_or(out_dir.join("dist"));
 
+    let build_cuda = option_env!("LLAMA_CPP_BUILD_CUDA").unwrap_or("OFF");
+    let cuda_archs =
+        option_env!("LLAMA_CPP_TARGET_CUDA_ARCHS").unwrap_or(CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS);
+
     let _ = cmake::Config::new(".")
         .uses_cxx11()
         .generator("Ninja")
@@ -29,9 +34,8 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf
         })
         .env("OPT_LEVEL", opt_level)
         .define("CMAKE_INSTALL_PREFIX", &install_path)
-        // .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
-        // .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
-        // .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
+        .define("LLAMA_CPP_BUILD_CUDA", build_cuda)
+        .define("LLAMA_CPP_TARGET_CUDA_ARCHS", cuda_archs)
         .build();
 
     // Additional transitive CMake dependencies
@@ -61,7 +65,7 @@ fn build_ffi_layer(deps_folder: &PathBuf) {
         .include(deps_folder.join("llama-src").join("ggml").join("include"))
         .include(deps_folder.join("llama-src").join("include"))
         .file("csrc/backend.cpp")
-        .std("c++20")
+        .std("c++23")
        .compile(CMAKE_LLAMA_CPP_TARGET);
 
    println!("cargo:rerun-if-changed=CMakeLists.txt");
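
The Rust hunks above (apparently the backend's build script) read two compile-time environment variables and forward them to CMake as cache definitions, so CUDA support stays opt-in and defaults to OFF. Below is a minimal, self-contained sketch of that pattern, assuming a standalone build.rs with the `cmake` crate as a build dependency; the variable names and the default architecture list come from the diff, everything else is illustrative rather than the project's actual build script.

// Sketch only: shows how option_env! and cmake::Config::define combine.
const CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS: &str = "75-real;80-real;86-real;89-real;90-real";

fn main() {
    // option_env! is read at compile time; unwrap_or supplies the default
    // when the variable is unset, so CUDA remains disabled unless requested.
    let build_cuda = option_env!("LLAMA_CPP_BUILD_CUDA").unwrap_or("OFF");
    let cuda_archs =
        option_env!("LLAMA_CPP_TARGET_CUDA_ARCHS").unwrap_or(CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS);

    // Forward both values to the CMake build as cache definitions.
    let dist = cmake::Config::new(".")
        .generator("Ninja")
        .define("LLAMA_CPP_BUILD_CUDA", build_cuda)
        .define("LLAMA_CPP_TARGET_CUDA_ARCHS", cuda_archs)
        .build();

    println!("cargo:rerun-if-changed=CMakeLists.txt");
    // The cmake crate returns the install prefix; libraries typically land under lib/.
    println!("cargo:rustc-link-search=native={}", dist.join("lib").display());
}

With this wiring, enabling CUDA would presumably amount to setting LLAMA_CPP_BUILD_CUDA=ON (and optionally LLAMA_CPP_TARGET_CUDA_ARCHS) in the environment before invoking cargo build.
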
@@ -10,14 +10,15 @@
 
 namespace huggingface::tgi::backends::llama {
 
-    std::unique_ptr<TgiLlamaCppBackend> CreateLlamaCppBackend(std::string_view root) {
+    std::unique_ptr<huggingface::tgi::backends::llama::TgiLlamaCppBackend>
+    CreateLlamaCppBackend(std::string_view root) {
         SPDLOG_INFO(FMT_STRING("Loading model from {}"), root);
         gpt_init();
 
         // Fake argv
         std::vector<std::string_view> args = {"tgi_llama_cpp_backend", "--model", root};
-        std::vector<char*> argv;
-        for(const auto& arg : args) {
+        std::vector<char *> argv;
+        for (const auto &arg: args) {
             argv.push_back(const_cast<char *>(arg.data()));
         }
         argv.push_back(nullptr);
@@ -39,35 +40,39 @@ namespace huggingface::tgi::backends::llama {
         auto loras = result.lora_adapters;
 
         // Make sure everything is correctly initialized
-        if(model == nullptr)
+        if (model == nullptr)
             throw std::runtime_error(fmt::format("Failed to load model from {}", root));
 
-        return std::make_unique<TgiLlamaCppBackend>(model, context);
+        return std::make_unique<huggingface::tgi::backends::llama::TgiLlamaCppBackend>(model, context);
     }
 
-    TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model, llama_context *const ctx)
-        : model(model), ctx(ctx), batch()
-    {
+    huggingface::tgi::backends::llama::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model,
+                                                                              llama_context *const ctx)
+        : model(model), ctx(ctx), batch() {
         char modelName[128];
         llama_model_meta_val_str(model, "general.name", modelName, sizeof(modelName));
         SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
     }
 
-    TgiLlamaCppBackend::~TgiLlamaCppBackend() {
-        if(model)
-        {
+    huggingface::tgi::backends::llama::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
+        if (model) {
             SPDLOG_DEBUG("Freeing llama.cpp model");
             llama_free_model(model);
         }
 
-        if(ctx)
-        {
+        if (ctx) {
             SPDLOG_DEBUG("Freeing llama.cpp context");
             llama_free(ctx);
         }
     }
 
-    void TgiLlamaCppBackend::schedule() {
+    void huggingface::tgi::backends::llama::TgiLlamaCppBackend::schedule() {
         std::vector<llama_token> tokens;
     }
 
+    namespace impl {
+        class LlamaCppBackendImpl {
+
+        };
+    }
+
 }