Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-08)
feat(llamacpp): enable cuda
Commit 05ad684676 (parent fa89d1e613)
build.rs
@@ -2,6 +2,7 @@ use cxx_build::CFG;
 use std::env;
 use std::path::PathBuf;

+const CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS: &str = "75-real;80-real;86-real;89-real;90-real";
 const CMAKE_LLAMA_CPP_TARGET: &str = "tgi_llama_cpp_backend_impl";
 const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
 const MPI_REQUIRED_VERSION: &str = "4.1";
@@ -20,6 +21,10 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf
         .map(|val| PathBuf::from(val))
         .unwrap_or(out_dir.join("dist"));

+    let build_cuda = option_env!("LLAMA_CPP_BUILD_CUDA").unwrap_or("OFF");
+    let cuda_archs =
+        option_env!("LLAMA_CPP_TARGET_CUDA_ARCHS").unwrap_or(CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS);
+
     let _ = cmake::Config::new(".")
         .uses_cxx11()
         .generator("Ninja")
@@ -29,9 +34,8 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> PathBuf
         })
         .env("OPT_LEVEL", opt_level)
         .define("CMAKE_INSTALL_PREFIX", &install_path)
-        // .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
-        // .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
-        // .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
+        .define("LLAMA_CPP_BUILD_CUDA", build_cuda)
+        .define("LLAMA_CPP_TARGET_CUDA_ARCHS", cuda_archs)
         .build();

     // Additional transitive CMake dependencies
@@ -61,7 +65,7 @@ fn build_ffi_layer(deps_folder: &PathBuf) {
         .include(deps_folder.join("llama-src").join("ggml").join("include"))
         .include(deps_folder.join("llama-src").join("include"))
         .file("csrc/backend.cpp")
-        .std("c++20")
+        .std("c++23")
         .compile(CMAKE_LLAMA_CPP_TARGET);

     println!("cargo:rerun-if-changed=CMakeLists.txt");
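The build-script hunks above boil down to a small configuration flow: two knobs are read at build-script compile time with `option_env!` and forwarded to the llama.cpp CMake build as cache definitions. Below is a minimal sketch of that flow, assuming the `cmake` crate; the helper name `configure_llama_cpp` and the `out_dir` handling are illustrative, not taken from this commit.

// Sketch only: how the CUDA knobs introduced above could be wired up.
use std::path::{Path, PathBuf};

const CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS: &str = "75-real;80-real;86-real;89-real;90-real";

// Hypothetical helper for illustration; the real build.rs does this inside
// build_backend().
fn configure_llama_cpp(out_dir: &Path) -> PathBuf {
    // `option_env!` reads the variables when the build script is compiled;
    // the fallbacks keep the CPU-only build as the default.
    let build_cuda = option_env!("LLAMA_CPP_BUILD_CUDA").unwrap_or("OFF");
    let cuda_archs =
        option_env!("LLAMA_CPP_TARGET_CUDA_ARCHS").unwrap_or(CMAKE_LLAMA_CPP_DEFAULT_CUDA_ARCHS);

    // Forward both values to CMake as cache definitions and run the build.
    cmake::Config::new(".")
        .generator("Ninja")
        .define("LLAMA_CPP_BUILD_CUDA", build_cuda)
        .define("LLAMA_CPP_TARGET_CUDA_ARCHS", cuda_archs)
        .out_dir(out_dir)
        .build()
}

With this in place, a CUDA build is opted into from the environment, e.g. `LLAMA_CPP_BUILD_CUDA=ON cargo build`, while the default stays "OFF" and the architecture list falls back to 75/80/86/89/90-real.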
csrc/backend.cpp
@@ -10,14 +10,15 @@

 namespace huggingface::tgi::backends::llama {

-    std::unique_ptr<TgiLlamaCppBackend> CreateLlamaCppBackend(std::string_view root) {
+    std::unique_ptr<huggingface::tgi::backends::llama::TgiLlamaCppBackend>
+    CreateLlamaCppBackend(std::string_view root) {
         SPDLOG_INFO(FMT_STRING("Loading model from {}"), root);
         gpt_init();

         // Fake argv
         std::vector<std::string_view> args = {"tgi_llama_cpp_backend", "--model", root};
-        std::vector<char*> argv;
-        for(const auto& arg : args) {
+        std::vector<char *> argv;
+        for (const auto &arg: args) {
             argv.push_back(const_cast<char *>(arg.data()));
         }
         argv.push_back(nullptr);
@@ -39,35 +40,39 @@ namespace huggingface::tgi::backends::llama {
         auto loras = result.lora_adapters;

         // Make sure everything is correctly initialized
-        if(model == nullptr)
+        if (model == nullptr)
             throw std::runtime_error(fmt::format("Failed to load model from {}", root));

-        return std::make_unique<TgiLlamaCppBackend>(model, context);
+        return std::make_unique<huggingface::tgi::backends::llama::TgiLlamaCppBackend>(model, context);
     }

-    TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model, llama_context *const ctx)
-    : model(model), ctx(ctx), batch()
-    {
+    huggingface::tgi::backends::llama::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model,
+                                                                              llama_context *const ctx)
+            : model(model), ctx(ctx), batch() {
         char modelName[128];
         llama_model_meta_val_str(model, "general.name", modelName, sizeof(modelName));
         SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
     }

-    TgiLlamaCppBackend::~TgiLlamaCppBackend() {
-        if(model)
-        {
+    huggingface::tgi::backends::llama::TgiLlamaCppBackend::~TgiLlamaCppBackend() {
+        if (model) {
             SPDLOG_DEBUG("Freeing llama.cpp model");
             llama_free_model(model);
         }

-        if(ctx)
-        {
+        if (ctx) {
             SPDLOG_DEBUG("Freeing llama.cpp context");
             llama_free(ctx);
         }
     }

-    void TgiLlamaCppBackend::schedule() {
+    void huggingface::tgi::backends::llama::TgiLlamaCppBackend::schedule() {
         std::vector<llama_token> tokens;
     }
+
+    namespace impl {
+        class LlamaCppBackendImpl {
+
+        };
+    }
 }
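The new, still empty impl::LlamaCppBackendImpl class is presumably the type meant to be surfaced to Rust through the cxx-based FFI layer that build.rs compiles (cxx_build, target tgi_llama_cpp_backend_impl). A hypothetical sketch of what such a bridge declaration could look like on the Rust side; the module name and header path are assumptions, not part of this commit.

// Sketch only: exposing the C++ class as an opaque type via the cxx crate.
#[cxx::bridge(namespace = "huggingface::tgi::backends::llama::impl")]
mod ffi {
    unsafe extern "C++" {
        // Header path is an assumption for illustration.
        include!("backend.hpp");

        // Opaque C++ type: owned and managed entirely on the C++ side.
        type LlamaCppBackendImpl;
    }
}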