huggingface/text-generation-inference

Enable end-to-end CMake build
commit f8a1463915 (parent 818162e0c2)

.devcontainer/Dockerfile.trtllm     (new file, empty)
.devcontainer/devcontainer.json     (new file, empty)

Cargo.lock    (generated)
@@ -3466,6 +3466,7 @@ dependencies = [
  "cxx-build",
  "text-generation-router",
  "thiserror",
+ "tokenizers",
  "tokio",
  "tokio-stream",
 ]

backends/trtllm/CMakeLists.txt
@@ -7,21 +7,28 @@ include(FetchContent)
 include(ExternalProject)
 
 option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
-set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "86-real;89-real;90-real" CACHE STRING "List of CUDA architectures to support")
+set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "native" CACHE STRING "List of CUDA architectures to support")
 set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE PATH "Path where TensorRT libraries and headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE PATH "Path where TensorRT headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE PATH "Path where TensorRT libraries are located")
 
-if (NOT EXISTS ${TGI_TRTLLM_BACKEND_TRT_ROOT})
-    message(FATAL_ERROR "TensorRT specified location: ${TGI_TRTLLM_BACKEND_TRT_ROOT} doesn't exist")
-else ()
-    if (NOT EXISTS ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
-        message(FATAL_ERROR "TensorRT headers were not found at: ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}")
-    endif ()
-
-    if (NOT EXISTS ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})
-        message(FATAL_ERROR "TensorRT libraries were not found at: ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR}")
-    endif ()
-endif ()
+#### Unit Tests ####
+if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
+    message(STATUS "Building tests")
+    FetchContent_Declare(
+            Catch2
+            GIT_REPOSITORY https://github.com/catchorg/Catch2
+            GIT_TAG v3.6.0
+    )
+    FetchContent_MakeAvailable(Catch2)
+
+    add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp)
+    target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain)
+
+    list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
+    include(CTest)
+    include(Catch)
+    catch_discover_tests(tgi_trtllm_backend_tests)
+endif ()
 
 #### External dependencies ####
@@ -34,23 +41,4 @@ target_include_directories(tgi_trtllm_backend_impl PRIVATE
         $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
         $<INSTALL_INTERFACE:include>
 )
 target_link_libraries(tgi_trtllm_backend_impl PUBLIC spdlog::spdlog tensorrt_llm nvinfer_plugin_tensorrt_llm)
-
-#### Unit Tests ####
-if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
-    message(STATUS "Building tests")
-    FetchContent_Declare(
-            Catch2
-            GIT_REPOSITORY https://github.com/catchorg/Catch2
-            GIT_TAG v3.6.0
-    )
-    FetchContent_MakeAvailable(Catch2)
-
-    add_executable(tgi_trtllm_backend_tests)
-    target_link_libraries(tests PRIVATE Catch2::Catch2::Catch2WithMain)
-
-    list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
-    include(CTest)
-    include(Catch)
-    catch_discover_tests(tests)
-endif ()

backends/trtllm/Cargo.toml
@@ -10,6 +10,7 @@ async-trait = "0.1.74"
 async-stream = "0.3.5"
 cxx = "1.0"
 text-generation-router = { path = "../../router" }
+tokenizers = { version = "0.19", features = ["hf-hub"] }
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.14"
 clap = { version = "4.5.4", features = ["derive"] }
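
The new tokenizers dependency (with its "hf-hub" feature) is what lets the backend binary pull a tokenizer straight from the Hugging Face Hub, as wired up in main.rs further down. A minimal sketch of the call it enables (the model id is illustrative only):

use tokenizers::Tokenizer;

fn load_tokenizer() -> tokenizers::Result<Tokenizer> {
    // Downloads the tokenizer definition from the Hub via the "hf-hub" feature.
    Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", None)
}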

backends/trtllm/build.rs
@@ -5,17 +5,53 @@ use cxx_build::CFG;
 
 const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
 
+// fn build_tensort_llm<P: AsRef<Path>>(tensorrt_llm_root_dir: P, is_debug: bool) -> PathBuf {
+//     let build_wheel_path = tensorrt_llm_root_dir
+//         .as_ref()
+//         .join("/scripts")
+//         .join("build_wheel.py");
+//
+//     let build_wheel_path_str = build_wheel_path.display().to_string();
+//     let mut build_wheel_args = vec![
+//         build_wheel_path_str.as_ref(),
+//         "--cpp_only",
+//         "--extra-cmake-vars BUILD_TESTS=OFF",
+//         "--extra-cmake-vars BUILD_BENCHMARKS=OFF",
+//     ];
+//
+//     if is_debug {
+//         build_wheel_args.push("--fast_build");
+//     }
+//
+//     let out = Command::new("python3")
+//         .args(build_wheel_args)
+//         .output()
+//         .expect("Failed to compile TensorRT-LLM");
+//     PathBuf::new().join(tensorrt_llm_root_dir)
+// }
 
 fn main() {
+    // Misc variables
     let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
     let build_profile = env::var("PROFILE").unwrap();
+    let is_debug = match build_profile.as_ref() {
+        "debug" => true,
+        _ => false,
+    };
+
+    // Compile TensorRT-LLM (as of today, it cannot be compiled from CMake)
+    // let trtllm_path = build_tensort_llm(
+    //     backend_path.join("build").join("_deps").join("trtllm-src"),
+    //     is_debug,
+    // );
+
     // Build the backend implementation through CMake
     let backend_path = cmake::Config::new(".")
         .uses_cxx11()
         .generator("Ninja")
-        .profile(match build_profile.as_ref() {
-            "release" => "Release",
-            _ => "Debug",
+        .profile(match is_debug {
+            true => "Debug",
+            false => "Release",
         })
         .build_target("tgi_trtllm_backend_impl")
         .build();
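
The build script drives the CMake project above through the cmake crate, deriving the CMake build type from Cargo's PROFILE. As a hypothetical extension (not part of this commit), the same invocation could also forward the cache variables declared in CMakeLists.txt, for instance to build the Catch2 test suite or point at a non-default TensorRT install; a sketch with illustrative values:

use std::env;

fn main() {
    let is_debug = env::var("PROFILE").unwrap() == "debug";
    let _backend_path = cmake::Config::new(".")
        .uses_cxx11()
        .generator("Ninja")
        .profile(if is_debug { "Debug" } else { "Release" })
        // Hypothetical: forward the cache variables exposed by CMakeLists.txt.
        .define("TGI_TRTLLM_BACKEND_BUILD_TESTS", "ON")
        .define("TGI_TRTLLM_BACKEND_TRT_ROOT", "/usr/local/tensorrt")
        .build_target("tgi_trtllm_backend_impl")
        .build();
}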

backends/trtllm/cmake/trtllm.cmake
@@ -1,3 +1,4 @@
+set(USE_CXX11_ABI ON)
 set(NVTX_DISABLE OFF)
 set(BUILD_PYT OFF)
 set(BUILD_PYBIND OFF)
@@ -8,6 +9,18 @@ set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
 set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})
 set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})
 
+#if (NOT EXISTS ${TGI_TRTLLM_BACKEND_TRT_ROOT})
+#    message(FATAL_ERROR "TensorRT specified location: ${TGI_TRTLLM_BACKEND_TRT_ROOT} doesn't exist")
+#else ()
+#    if (NOT EXISTS ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
+#        message(FATAL_ERROR "TensorRT headers were not found at: ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}")
+#    endif ()
+#
+#    if (NOT EXISTS ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})
+#        message(FATAL_ERROR "TensorRT libraries were not found at: ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR}")
+#    endif ()
+#endif ()
+
 message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
 
 if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
@@ -24,7 +37,7 @@ fetchcontent_declare(
 )
 fetchcontent_makeavailable(trtllm)
 message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
-execute_process(COMMAND git lfs init WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
+execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
 execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
 add_subdirectory("${trtllm_SOURCE_DIR}/cpp")
 include_directories("${trtllm_SOURCE_DIR}/cpp/include")

backends/trtllm/cmake/utils/detect_cuda_arch.cu    (new file, empty)

backends/trtllm/include/backend.h
@@ -6,19 +6,55 @@
 #define TGI_TRTLLM_BACKEND_H
 
 #include <filesystem>
+#include <span>
 
-//#include <tensorrt_llm/runtime/common.h>
-//#include <tensorrt_llm/executor/executor.h>
-//
-//namespace tle = tensorrt_llm::executor;
+#include <tensorrt_llm/runtime/common.h>
+#include <tensorrt_llm/executor/executor.h>
+
+namespace tle = tensorrt_llm::executor;
 
 namespace huggingface::tgi::backends {
     class TensorRtLlmBackend {
     private:
-        // tle::Executor executor;
+        tle::Executor executor;
 
     public:
-        TensorRtLlmBackend(const std::filesystem::path &engineFolder);
+        explicit TensorRtLlmBackend(const std::filesystem::path &engineFolder);
+
+        /***
+         * Indicate if the backend is ready to accept incoming request
+         * @return true if ready, false otherwise
+         */
+        [[nodiscard]] bool IsReady() const {
+            return executor.canEnqueueRequests();
+        }
+
+        /***
+         *
+         * @param tokens
+         * @param maxNewTokens
+         * @param topK
+         * @param topP
+         * @param temperature
+         * @param minLength
+         * @param repetitionPenalty
+         * @param frequencePenalty
+         * @param seed
+         * @param nTopTokens
+         * @return
+         */
+        [[nodiscard]] tle::IdType Submit(
+                std::vector<tle::TokenIdType> &tokens,
+                int32_t maxNewTokens,
+                float_t topK,
+                float_t topP,
+                float_t temperature,
+                int32_t minLength,
+                std::optional<float_t> repetitionPenalty = std::nullopt,
+                std::optional<float_t> frequencePenalty = std::nullopt,
+                std::optional<uint32_t> seed = std::nullopt,
+                std::optional<uint32_t> nTopTokens = std::nullopt
+        );
     };
 }
 

backends/trtllm/lib/backend.cpp
@@ -3,6 +3,49 @@
 
 #include "backend.h"
 
-huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(const std::filesystem::path &engineFolder) {
+huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(const std::filesystem::path &engineFolder)
+        : executor(engineFolder, tle::ModelType::kDECODER_ONLY, tle::ExecutorConfig{}) {
     SPDLOG_INFO(FMT_STRING("Loading engines from {}"), engineFolder);
 }
 
+tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
+        std::vector<tle::TokenIdType> &tokens,
+        const int32_t maxNewTokens,
+        const float_t topK,
+        const float_t topP,
+        const float_t temperature,
+        const int32_t minLength,
+        const std::optional<float_t> repetitionPenalty,
+        const std::optional<float_t> frequencePenalty,
+        const std::optional<uint32_t> seed,
+        const std::optional<uint32_t> nTopTokens
+) {
+    if (IsReady()) {
+        spdlog::debug(
+                "Submitting inference over {:d} tokens to the executor {:d}",
+                tokens.size(),
+                executor.getLatestIterationStats().back().numActiveRequests
+        );
+
+        const auto sampling = tle::SamplingConfig{
+                1,
+                topK,
+                topP,
+                std::nullopt,
+                std::nullopt,
+                std::nullopt,
+                seed,
+                temperature,
+                minLength,
+                std::nullopt,
+                repetitionPenalty.value_or(0.0),
+                std::nullopt,
+                frequencePenalty.value_or(1.0),
+        };
+        const auto output = tle::OutputConfig{false, false, nTopTokens.value_or(1) > 1};
+        const auto request = tle::Request{std::move(tokens), maxNewTokens, true, sampling, output};
+
+        return executor.enqueueRequest(request);
+    }
+    return 0;
+}

backends/trtllm/src/backend.rs
@@ -2,6 +2,8 @@ use std::path::Path;
 
 use async_trait::async_trait;
 use cxx::UniquePtr;
+use tokenizers::Tokenizer;
+use tokio::sync::mpsc;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
@@ -11,6 +13,7 @@ use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_trtllm_backend, TensorRtLlmBackend};
 
 pub struct TrtLLmBackend {
+    tokenizer: Tokenizer,
     inner: UniquePtr<TensorRtLlmBackend>,
 }
 
@@ -18,11 +21,14 @@ unsafe impl Sync for TrtLLmBackend {}
 unsafe impl Send for TrtLLmBackend {}
 
 impl TrtLLmBackend {
-    pub fn new<P: AsRef<Path>>(engine_folder: P) -> Result<Self, TensorRtLlmBackendError> {
+    pub fn new<P: AsRef<Path>>(
+        tokenizer: Tokenizer,
+        engine_folder: P,
+    ) -> Result<Self, TensorRtLlmBackendError> {
         let engine_folder = engine_folder.as_ref();
         let inner = create_trtllm_backend(engine_folder.to_str().unwrap());
 
-        Ok(Self { inner })
+        Ok(Self { tokenizer, inner })
     }
 }
 
@@ -30,12 +36,15 @@ impl TrtLLmBackend {
 impl Backend for TrtLLmBackend {
     fn schedule(
         &self,
-        _request: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        todo!()
+        let (sender, receiver) = mpsc::unbounded_channel();
+        let request_id = self.inner.submit();
+
+        Ok(UnboundedReceiverStream::new(receiver))
    }
 
     async fn health(&self, _current_health: bool) -> bool {
-        true
+        self.inner.is_ready()
     }
 }
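
A short usage sketch of the updated constructor and health check (the model id and engine path are illustrative, and it assumes TensorRtLlmBackendError implements std::error::Error); the tokenizer is presumably stored so the request text can be encoded before submission in a follow-up change:

use tokenizers::Tokenizer;
use text_generation_backends_trtllm::TrtLLmBackend;

fn build_backend() -> Result<TrtLLmBackend, Box<dyn std::error::Error>> {
    let tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", None)?;
    // The backend loads the serialized TensorRT-LLM engines found in this folder.
    let backend = TrtLLmBackend::new(tokenizer, "/data/engines/llama3-8b-instruct")?;
    Ok(backend)
}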

backends/trtllm/src/ffi.cpp
@@ -14,7 +14,7 @@ namespace huggingface::tgi::backends {
      */
     std::unique_ptr<TensorRtLlmBackend> create_trtllm_backend(rust::Str engineFolder) {
         const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
-        return std::make_unique<TensorRtLlmBackend>(enginePath);
+        return std::make_unique<TensorRtLlmBackend>(std::move(enginePath));
     }
 
 }

backends/trtllm/src/lib.rs
@@ -11,5 +11,11 @@ mod ffi {
         type TensorRtLlmBackend;
 
         fn create_trtllm_backend(engine_folder: &str) -> UniquePtr<TensorRtLlmBackend>;
+
+        #[rust_name = "is_ready"]
+        fn IsReady(&self) -> bool;
+
+        #[rust_name = "submit"]
+        fn Submit(&self) -> u64;
     }
 }
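
A minimal sketch of how this bridge is consumed from Rust (the engine path is illustrative): create_trtllm_backend hands back a cxx::UniquePtr, and the #[rust_name] attributes expose the C++ IsReady()/Submit() methods under snake_case names, exactly as backend.rs uses them.

fn probe_backend() -> u64 {
    // UniquePtr<TensorRtLlmBackend> dereferences to the opaque C++ type,
    // so the bridged methods can be called on it directly.
    let backend = crate::ffi::create_trtllm_backend("/data/engines/llama3-8b-instruct");
    if backend.is_ready() {
        backend.submit() // TensorRT-LLM request id (u64 on the Rust side)
    } else {
        0
    }
}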

backends/trtllm/src/main.rs
@@ -1,4 +1,7 @@
+use std::collections::HashMap;
+
 use clap::Parser;
+use tokenizers::{FromPretrainedParameters, Tokenizer};
 
 use text_generation_backends_trtllm::{errors::TensorRtLlmBackendError, TrtLLmBackend};
 use text_generation_router::server;
@@ -109,7 +112,15 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
     }
 
     // Run server
-    let backend = TrtLLmBackend::new(model_id)?;
+    let tokenizer = Tokenizer::from_pretrained(
+        tokenizer_name.clone(),
+        Some(FromPretrainedParameters {
+            revision: revision.clone().unwrap_or(String::from("main")),
+            user_agent: HashMap::new(),
+            auth_token,
+        }),
+    )?;
+    let backend = TrtLLmBackend::new(tokenizer, model_id)?;
     server::run(
         backend,
         max_concurrent_requests,

backends/trtllm/tests/infer_test.cpp    (new file)
@@ -0,0 +1,9 @@
+//
+// Created by mfuntowicz on 7/2/24.
+//
+#include <catch2/catch_all.hpp>
+#include "../include/backend.h"
+
+TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") {
+    huggingface::tgi::backends::TensorRtLlmBackend backend("fixtures/engines/llama3-8b-instruct.engine");
+}