feat(trtllm): expose finish reason to Rust

Morgan Funtowicz 2024-12-10 16:51:22 +01:00
parent f0cd4742c2
commit 60059b6968
10 changed files with 302 additions and 175 deletions

File: backends/trtllm/CMakeLists.txt

@@ -1,11 +1,19 @@
 cmake_minimum_required(VERSION 3.20)
 
-if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
+if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
     find_program(CCACHE_EXECUTABLE "ccache")
     if (CCACHE_EXECUTABLE)
         message(STATUS "Using ccache")
-        set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
+        set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+        set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+        set(CMAKE_CUDA_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
     endif ()
+else ()
+    find_program(CCACHE_EXECUTABLE ${CMAKE_CXX_COMPILER_LAUNCHER})
+    message(STATUS "Using user specified cmake cxx compiler launcher: ${CMAKE_CXX_COMPILER_LAUNCHER}")
+    set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+    set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+    set(CMAKE_CUDA_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
 endif ()
 
 if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
@@ -21,28 +29,37 @@ include(CheckCXXCompilerFlag)
 option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
 option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
+option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
 set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
 set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
 
 # We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
 find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
+find_package(MPI REQUIRED)
 
 #### External dependencies ####
 include(cmake/json.cmake)
 include(cmake/spdlog.cmake)
 include(cmake/trtllm.cmake)
 
-if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
+if (CMAKE_BUILD_TYPE STREQUAL "Debug")
+    set(TGI_TRTLLM_BACKEND_DEBUG ON)
     add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1)
-endif()
+endif ()
+
+if (${TGI_TRTLLM_BACKEND_BUILD_USE_LLD})
+    message(STATUS "Using lld linker")
+    add_link_options("-fuse-ld=lld")
+endif ()
 
 # This attempt to detect if the compiler can emit warning if it can't apply return value optimization from a function
 check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NVRO)
-if(${COMPILER_SUPPORT_WARNING_ON_NVRO})
-    set(CMAKE_CXX_FLAGS "{CMAKE_CXX_FLAGS} -Wnvro")
-endif()
+if (${COMPILER_SUPPORT_WARNING_ON_NVRO})
+    message(STATUS "Enabling non-NVRO detection")
+    target_compile_options(tgi_trtllm_backend_impl "-Werror -Wnvro")
+endif ()
 
 # Let's build TRTLLM as part of CMake
 add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
@@ -55,21 +72,20 @@ add_library(tgi_trtllm_backend_impl STATIC csrc/hardware.hpp csrc/backend.hpp cs
 include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
 target_include_directories(tgi_trtllm_backend_impl PRIVATE
         $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/csrc>
         # $<INSTALL_INTERFACE:csrc>
 )
 target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
 target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)
 target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
-if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-    target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
-else()
-    target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
-endif ()
+target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
 
 # This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
 install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
-install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} TYPE LIB)
+if (NOT ${TGI_TRTLLM_BACKEND_DEBUG})
+    install(FILES ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+endif ()
 
 #### Unit Tests ####
 if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
@@ -85,18 +101,13 @@ if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
     target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
     target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)
     target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
+    target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
 
-    if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-        target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
-    else()
-        target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
-    endif ()
-
-    if(CMAKE_BUILD_TYPE MATCHES "Debug")
-        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
-        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
+    if (CMAKE_BUILD_TYPE MATCHES "Debug")
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fsanitize=undefined")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fsanitize=undefined")
         target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined PUBLIC -fsanitize=address)
-    endif()
+    endif ()
 
     list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
     include(CTest)

File: backends/trtllm/build.rs

@@ -15,9 +15,8 @@ const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
 // Dependencies
 const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
 const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
-const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
+const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 4] = [
     ("dylib", "tensorrt_llm"),
-    ("static", "tensorrt_llm_executor_static"),
     ("dylib", "tensorrt_llm_nvrtc_wrapper"),
     ("dylib", "nvinfer_plugin_tensorrt_llm"),
     ("dylib", "decoder_attention"),
@@ -32,6 +31,58 @@ macro_rules! probe {
     };
 }
 
+fn get_compiler_flag(
+    switch: bool,
+    true_case: &'static str,
+    false_case: &'static str,
+) -> &'static str {
+    match switch {
+        true => true_case,
+        false => false_case,
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+fn get_system_install_path(install_path: &PathBuf) -> PathBuf {
+    install_path.join("lib64")
+}
+
+#[cfg(not(target_arch = "x86_64"))]
+fn get_system_install_path(install_path: &PathBuf) -> PathBuf {
+    install_path.join("lib")
+}
+
+fn get_library_architecture() -> &'static str {
+    let os = env::var("CARGO_CFG_TARGET_OS").unwrap();
+    let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
+    let env = env::var("CARGO_CFG_TARGET_ENV").unwrap();
+
+    match os.as_str() {
+        "linux" => {
+            if env != "gnu" {
+                panic!("unsupported linux ABI {env}, only 'gnu' is supported")
+            }
+
+            match arch.as_str() {
+                "x86_64" => "x86_64-linux-gnu",
+                "aarch64" => "aarch64-linux-gnu",
+                _ => panic!("unsupported linux architecture {arch}"),
+            }
+        }
+        "windows" => {
+            if env != "msvc" {
+                panic!("unsupported windows ABI {env}, only 'msvc' is supported")
+            }
+
+            match arch.as_str() {
+                "x86_64" => "x86_64-windows-msvc",
+                _ => panic!("unsupported windows architecture {arch}"),
+            }
+        }
+        _ => panic!("unsupported OS {os}"),
+    }
+}
+
 fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
     // Build the backend implementation through CMake
     let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
@@ -44,7 +95,8 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
     }
 
     let mut config = cmake::Config::new(".");
-    config.uses_cxx11()
+    config
+        .uses_cxx11()
         .generator("Ninja")
         .profile(match is_debug {
             true => "Debug",
@@ -53,16 +105,28 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
         .env("OPT_LEVEL", opt_level)
         .define("CMAKE_INSTALL_PREFIX", &install_path)
         .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
-        .define("Python3_ROOT_DIR", "../venv")
+        .define("CMAKE_LIBRARY_ARCHITECTURE", get_library_architecture())
        .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
+        .define(
+            "TGI_TRTLLM_BACKEND_DEBUG",
+            get_compiler_flag(is_debug, "ON", "OFF"),
+        )
        .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path);
 
-    // Allow to override which Python to use ...
-    if let Some(python3) = option_env!("Python3_EXECUTABLE") {
-        config.define("Python3_EXECUTABLE", python3);
-    }
-
-    config.build();
+    if let Some(nvcc_host_compiler) = option_env!("CMAKE_CUDA_HOST_COMPILER") {
+        config.define("CMAKE_CUDA_HOST_COMPILER", nvcc_host_compiler);
+    }
+
+    if let Some(cxx_compiler_launcher) = option_env!("CMAKE_CXX_COMPILER_LAUNCHER") {
+        config.define("CMAKE_CXX_COMPILER_LAUNCHER", cxx_compiler_launcher);
+    }
+
+    // Allow to override which Python to use ...
+    if let Some(python3) = option_env!("Python3_EXECUTABLE") {
+        config.define("Python3_EXECUTABLE", python3);
+    }
+
+    config.build();
 
     // Additional transitive CMake dependencies
     let deps_folder = out_dir.join("build").join("_deps");
@@ -77,7 +141,8 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
     }
 
     // Emit linkage information from the artifacts we just built
-    let install_lib_path = install_path.join("lib");
+    let install_lib_path = get_system_install_path(&install_path);
+
     println!(
         r"cargo:warning=Adding link search path: {}",
@@ -89,11 +154,6 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
 }
 
 fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
-    let ndebug = match is_debug {
-        true => "1",
-        false => "0",
-    };
-
     CFG.include_prefix = "backends/trtllm";
     cxx_build::bridge("src/lib.rs")
         .static_flag(true)
@@ -105,7 +165,10 @@ fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
         .include("/usr/local/tensorrt/include")
         .include("csrc/")
         .file("csrc/ffi.hpp")
-        .define("TGI_TRTLLM_BACKEND_DEBUG", ndebug)
+        .define(
+            "TGI_TRTLLM_BACKEND_DEBUG",
+            get_compiler_flag(is_debug, "ON", "OFF"),
+        )
         .compile("tgi_trtllm_backend");
 
     println!("cargo:rerun-if-changed=CMakeLists.txt");

File: backends/trtllm/cmake/spdlog.cmake

@@ -4,14 +4,14 @@ set(SPDLOG_FMT_EXTERNAL OFF)
 # Define the level at which SPDLOG_ compilation level is defined
 if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
 else ()
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
 endif ()
 
 fetchcontent_declare(
         spdlog
         # DOWNLOAD_EXTRACT_TIMESTAMP
-        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
+        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.0.tar.gz
 )
 fetchcontent_makeavailable(spdlog)

File: backends/trtllm/cmake/trtllm.cmake

@@ -14,11 +14,13 @@ message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
 set(ENABLE_UCX OFF)
 if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
     set(FAST_BUILD ON)
-    set(NVTX_DISABLE OFF)
+    set(NVTX_DISABLE ON)
+    set(INDEX_RANGE_CHECK ON)
 else ()
     set(FAST_BUILD OFF)
     set(FAST_MATH ON)
-    set(NVTX_DISABLE ON)
+    set(NVTX_DISABLE OFF)
+    set(INDEX_RANGE_CHECK OFF)
 endif ()
 
 find_package(Python3 REQUIRED Interpreter)

File: backends/trtllm/csrc/backend.cpp

@@ -1,7 +1,6 @@
 #include <ranges>
 
 #include <nlohmann/json.hpp>
-#include <spdlog/spdlog.h>
 
 #include "backend.hpp"
 #include "hardware.hpp"
@@ -17,7 +16,8 @@ namespace huggingface::tgi::backends::trtllm {
         if (world_size > 1) {
             SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
             mode = tle::CommunicationMode::kORCHESTRATOR;
-            orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr, true);
+            orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr,
+                                                                             true);
         } else {
             SPDLOG_INFO("Detected single engine deployment, using leader mode");
         }
@@ -44,21 +44,22 @@ namespace huggingface::tgi::backends::trtllm {
     }
 
     backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path)
             : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {}
 
     size_t backend_t::num_tokens_ready() const noexcept {
         return executor_.getNumResponsesReady();
     }
 
     std::expected<request_id_t, backend_error_t>
-    backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t generation_params, const sampling_params_t sampling_params) noexcept {
-        SPDLOG_DEBUG("Submitting {:d} tokens to the executor for scheduling ({}, {})", token_ids.size(), generation_params, sampling_params);
-        return executor_.enqueueRequest(tle::Request {
+    backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t g_params,
+                      const sampling_params_t s_params) noexcept {
+        SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params);
+        return executor_.enqueueRequest(tle::Request{
                 {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens
-                static_cast<tle::SizeType32>(generation_params.max_new_tokens),
+                static_cast<tle::SizeType32>(g_params.max_new_tokens),
                 true,
-                (tle::SamplingConfig) sampling_params,
-                tle::OutputConfig { /* returnLogProbs= */ true },
+                (tle::SamplingConfig) s_params,
+                tle::OutputConfig{ /* returnLogProbs= */ true},
                 std::nullopt,
                 std::nullopt,
                 std::nullopt,

File: backends/trtllm/csrc/ffi.hpp

@@ -28,20 +28,61 @@ namespace huggingface::tgi::backends::trtllm {
 #include "backends/trtllm/src/lib.rs.h"
 
 namespace huggingface::tgi::backends::trtllm {
     std::once_flag backend_initialized_flag;
 
+    constexpr finish_reason_t as_finish_reason_t(const tle::FinishReason reason) noexcept {
+        switch (reason) {
+            case tle::FinishReason::kNOT_FINISHED:
+                return finish_reason_t::kNOT_FINISHED;
+            case tle::FinishReason::kSTOP_WORDS:
+                return finish_reason_t::kSTOP_WORDS;
+            case tle::FinishReason::kEND_ID:
+                return finish_reason_t::kEND_ID;
+            case tle::FinishReason::kLENGTH:
+                return finish_reason_t::kLENGTH;
+            default:
+                std::unreachable();
+        }
+    }
+
+    static auto as_generation_step = [](const tle::Response &r) {
+        const auto reqId = r.getRequestId();
+        if (!r.hasError()) [[likely]] {
+            const auto result = r.getResult();
+            return generation_step_t{
+                    reqId,
+                    static_cast<uint32_t>(result.outputTokenIds[0][0]),
+                    result.logProbs.value()[0][0],
+                    result.isFinal,
+                    as_finish_reason_t(result.finishReasons[0]),
+                    false,
+                    std::string()
+            };
+        } else {
+            return generation_step_t{
+                    reqId,
+                    0,
+                    0.0,
+                    true,
+                    finish_reason_t::kNOT_FINISHED,
+                    true,
+                    std::move(r.getErrorMsg())
+            };
+        }
+    };
+
     class tensorrt_llm_backend_t {
     private:
         backend_t inner_;
 
     public:
         tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
                 : inner_(engine_folder, executor_worker_path) {}
 
-        size_t num_tokens_ready() const noexcept {
-            return inner_.num_tokens_ready();
-        }
+        size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }
 
         request_id_t submit(
                 rust::Slice<const uint32_t> tokens,
@@ -59,13 +100,13 @@ namespace huggingface::tgi::backends::trtllm {
             // Submit the request to the executor and get back a potential request_id used to track request status
             const auto signed_tokens = std::vector<int32_t>(tokens.begin(), tokens.end());
             const auto maybe_request_id = inner_.submit(
                     signed_tokens,
                    {max_new_tokens},
                    {top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed}
             );
 
             // If we do have a value, let's return the request_id
-            if(maybe_request_id.has_value()) [[likely]] {
+            if (maybe_request_id.has_value()) [[likely]] {
                 return *maybe_request_id;
             } else {
                 SPDLOG_WARN("[FFI] Failed to submit request to the executor");
@@ -74,61 +115,45 @@ namespace huggingface::tgi::backends::trtllm {
         }
 
         std::unique_ptr<std::vector<generation_step_t>> pull_tokens() noexcept {
-            if(num_tokens_ready() > 0) [[likely]] {
+            if (num_tokens_ready() > 0) [[likely]] {
                 const auto responses = inner_.pull_tokens();
 
                 SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());
-                // Transform tle::Response to GenerationStep
-                auto steps = std::make_unique<std::vector<generation_step_t>>();
-                std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
-                    const auto reqId = r.getRequestId();
-                    if (!r.hasError()) [[likely]] {
-                        const auto result = r.getResult();
-                        return generation_step_t{
-                                reqId,
-                                static_cast<uint32_t>(result.outputTokenIds[0][0]),
-                                result.logProbs.value()[0][0],
-                                result.isFinal,
-                                false,
-                                std::string()
-                        };
-                    } else {
-                        return generation_step_t{
-                                reqId,
-                                0,
-                                0.0,
-                                true,
-                                true,
-                                std::move(r.getErrorMsg())
-                        };
-                    }
-                });
-                return steps;
+
+                // Transform tle::Response to generation_step_t
+#ifdef __cpp_lib_ranges_to_container
+                auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to<std::vector>();
+#else
+                auto steps = std::vector<generation_step_t>();
+                steps.reserve(responses.size());
+                std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
+#endif
+                return std::make_unique<std::vector<generation_step_t>>(steps);
+
             } else {
                 return std::make_unique<std::vector<generation_step_t>>();
             }
         }
 
-        void cancel(request_id_t requestId) noexcept {
-            SPDLOG_DEBUG("[FFI] cancelling request {:d}", requestId);
-            inner_.cancel(requestId);
+        void cancel(request_id_t request_id) noexcept {
+            SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id);
+            inner_.cancel(request_id);
         }
     };
 
     void initialize_logging() {
 #ifndef TGI_TRTLLM_BACKEND_DEBUG
         if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
             std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
             std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
                 return std::tolower(c);
             });
 
             if (log_level == "debug")
                 spdlog::set_level(spdlog::level::debug);
             else
                 spdlog::set_level(spdlog::level::info);
         }
 #else
         spdlog::set_level(spdlog::level::debug);
 #endif
@@ -151,11 +176,14 @@ namespace huggingface::tgi::backends::trtllm {
         }
     }
 
-    std::unique_ptr<tensorrt_llm_backend_t> create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
+    std::unique_ptr<tensorrt_llm_backend_t>
+    create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
         std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
         return std::make_unique<tensorrt_llm_backend_t>(
-                std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()), std::filesystem::path::format::auto_format),
-                std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()), std::filesystem::path::format::auto_format)
+                std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
+                                      std::filesystem::path::format::auto_format),
+                std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
+                                      std::filesystem::path::format::auto_format)
         );
     }
 }

File: backends/trtllm/src/lib.rs

@@ -6,6 +6,26 @@ mod utils;
 #[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
 mod ffi {
+    #[cxx_name = "finish_reason_t"]
+    #[derive(Debug, Clone, Copy)]
+    pub enum FinishReason {
+        /// The request is not finished.
+        #[cxx_name = "kNOT_FINISHED"]
+        NotFinished = 0u8,
+
+        /// The request finished because the end id was generated.
+        #[cxx_name = "kEND_ID"]
+        EndTokenId = 1u8,
+
+        /// The request finished because a stop word was generated.
+        #[cxx_name = "kSTOP_WORDS"]
+        StopWords = 2u8,
+
+        /// The request finished because the maximum number of tokens was reached.
+        #[cxx_name = "kLENGTH"]
+        MaxLength = 3u8,
+    }
+
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
     #[cxx_name = "generation_step_t"]
@@ -15,6 +35,7 @@ mod ffi {
         token_id: u32,
         log_prob: f32,
         is_final: bool,
+        finish_reason: FinishReason,
         has_error: bool,
         error_msg: String,
     }
@@ -66,3 +87,17 @@ mod ffi {
         fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
     }
 }
+
+use ffi::FinishReason;
+use text_generation_router::FinishReason as InferFinishReason;
+
+impl From<FinishReason> for InferFinishReason {
+    fn from(reason: FinishReason) -> Self {
+        match reason {
+            FinishReason::StopWords => InferFinishReason::StopSequence,
+            FinishReason::MaxLength => InferFinishReason::Length,
+            FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken,
+            _ => panic!("Cannot convert {reason:?} to text_generation_router::FinishReason"),
+        }
+    }
+}
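To make the new mapping concrete, here is a self-contained sketch of the conversion performed by the From impl above. The RouterFinishReason enum below is a hypothetical stand-in for text_generation_router::FinishReason; only the match arms mirror the diff.

    // Stand-in types for illustration only; the real ones live in the ffi module
    // and in the text_generation_router crate.
    #[derive(Debug, Clone, Copy)]
    enum FfiFinishReason {
        NotFinished,
        EndTokenId,
        StopWords,
        MaxLength,
    }

    #[derive(Debug, PartialEq)]
    enum RouterFinishReason {
        Length,
        EndOfSequenceToken,
        StopSequence,
    }

    fn to_router_reason(reason: FfiFinishReason) -> RouterFinishReason {
        match reason {
            FfiFinishReason::StopWords => RouterFinishReason::StopSequence,
            FfiFinishReason::MaxLength => RouterFinishReason::Length,
            FfiFinishReason::EndTokenId => RouterFinishReason::EndOfSequenceToken,
            // Mirrors the panic arm in the diff: a not-finished step should never
            // be converted into a terminal finish reason.
            FfiFinishReason::NotFinished => panic!("cannot convert NotFinished"),
        }
    }

    fn main() {
        assert_eq!(to_router_reason(FfiFinishReason::MaxLength), RouterFinishReason::Length);
        assert_eq!(to_router_reason(FfiFinishReason::StopWords), RouterFinishReason::StopSequence);
    }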

File: backends/trtllm/src/looper.rs

@@ -10,7 +10,7 @@ use tokio::sync::TryAcquireError;
 use tokio::task::spawn_blocking;
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, error, warn};
+use tracing::{debug, error, info, warn};
 
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -18,10 +18,12 @@ use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
 use text_generation_router::validation::{Chunk, ValidGenerateRequest};
-use text_generation_router::{FinishReason, Token};
+use text_generation_router::Token;
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_backend_from_engine_folder, GenerationStep, TensorRtLlmBackendImpl};
+use crate::ffi::{
+    create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
+};
 use crate::utils::first_line;
 
 type InferResult<T> = Result<T, InferError>;
@@ -40,6 +42,7 @@ struct DecodedToken {
     id: u32,
     log_prob: f32,
     is_final: bool,
+    finish_reason: FinishReason,
 }
 
 impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
@@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
                 id: step.token_id,
                 log_prob: step.log_prob,
                 is_final: step.is_final,
+                finish_reason: step.finish_reason,
             })
         } else {
             Err(GenerationError(step.error_msg.clone()))
@@ -133,13 +137,16 @@ fn executor_status_looper(
                         Ok(decoded_token) => {
                             post_process_decoded_token(&tokenizer, ctx, decoded_token)
                         }
-                        Err(err) => Err(err)
+                        Err(err) => Err(err),
                     };
 
                     // Attempt to send back the response to the client
                     if let Err(_) = ctx.streamer.send(response) {
                         // Client has dropped, remove from tracked requests
-                        debug!("Client dropped - removing request {} from tracked requests", step.request_id);
+                        info!(
+                            "Client dropped - removing request {} from tracked requests",
+                            step.request_id
+                        );
                         backend.as_mut().cancel(step.request_id);
                         let _ = in_flights.remove(&step.request_id);
                     }
@@ -160,11 +167,14 @@ fn executor_status_looper(
     }
 }
 
-fn post_process_decoded_token(tokenizer: &Tokenizer, ctx: &mut GenerationContext, decoded_token: DecodedToken) -> InferResult<InferStreamResponse> {
+fn post_process_decoded_token(
+    tokenizer: &Tokenizer,
+    ctx: &mut GenerationContext,
+    decoded_token: DecodedToken,
+) -> InferResult<InferStreamResponse> {
     match tokenizer.decode(&[decoded_token.id], false) {
         Ok(text) => {
-            let is_special =
-                tokenizer.get_added_vocabulary().is_special_token(&text);
+            let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
             let token = Token {
                 id: decoded_token.id,
                 text,
@@ -186,7 +196,7 @@ fn post_process_decoded_token(
             let generated_text = GeneratedText {
                 text: text.unwrap(),
                 generated_tokens: ctx.tokens.len() as u32,
-                finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
+                finish_reason: decoded_token.finish_reason.into(),
                 seed: None,
             };
@@ -248,7 +258,6 @@ unsafe impl Send for TensorRtLlmBackendImpl {}
 pub struct TensorRtLlmBackendV2(UnboundedSender<GenerationContext>);
 
 impl TensorRtLlmBackendV2 {
     pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
         tokenizer: Tokenizer,
@@ -268,12 +277,7 @@ impl TensorRtLlmBackendV2 {
         // Executor looper is responsible for scheduling and pulling requests state at regular interval
         spawn_blocking(move || {
-            executor_status_looper(
-                max_inflight_requests,
-                tokenizer,
-                backend,
-                executor_receiver,
-            )
+            executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver)
         });
 
         Ok(TensorRtLlmBackendV2(executor_sender))
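For readers following the data flow: each GenerationStep coming over the FFI either carries a token (now including its finish reason) or an error message. A compact sketch with hypothetical stand-in structs, mirroring the shape of the TryFrom conversion above; field names follow the diff, everything else is illustrative.

    #[derive(Debug, Clone, Copy, PartialEq)]
    enum FinishReason { NotFinished, EndTokenId, StopWords, MaxLength }

    struct GenerationStep {
        token_id: u32,
        log_prob: f32,
        is_final: bool,
        finish_reason: FinishReason,
        has_error: bool,
        error_msg: String,
    }

    struct DecodedToken {
        id: u32,
        log_prob: f32,
        is_final: bool,
        finish_reason: FinishReason,
    }

    fn decode(step: &GenerationStep) -> Result<DecodedToken, String> {
        if !step.has_error {
            Ok(DecodedToken {
                id: step.token_id,
                log_prob: step.log_prob,
                is_final: step.is_final,
                // New in this commit: the finish reason travels with the token
                // instead of being hard-coded to EndOfSequenceToken at the end.
                finish_reason: step.finish_reason,
            })
        } else {
            Err(step.error_msg.clone())
        }
    }

    fn main() {
        let step = GenerationStep {
            token_id: 42,
            log_prob: -0.25,
            is_final: true,
            finish_reason: FinishReason::MaxLength,
            has_error: false,
            error_msg: String::new(),
        };
        assert_eq!(decode(&step).unwrap().finish_reason, FinishReason::MaxLength);
    }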

File: backends/trtllm/src/main.rs

@@ -7,9 +7,11 @@ use tracing::info;
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackendV2;
+use text_generation_router::server::{
+    get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer,
+};
 use text_generation_router::usage_stats::UsageStatsLevel;
-use text_generation_router::{server, HubTokenizerConfig, Tokenizer};
-use text_generation_router::server::{get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer};
+use text_generation_router::{server, Tokenizer};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -65,11 +67,7 @@ struct Args {
     payload_limit: usize,
 }
 
-async fn get_tokenizer(
-    tokenizer_name: &str,
-    tokenizer_config_path: Option<&str>,
-    revision: Option<&str>,
-) -> Option<Tokenizer> {
+async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
     // Parse Huggingface hub token
     let authorization_token = std::env::var("HF_TOKEN")
         .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
@@ -129,14 +127,14 @@ async fn get_tokenizer(
         _tokenizer_config_filename,
         _preprocessor_config_filename,
         _processor_config_filename,
-        _model_info
+        _model_info,
     ) = match api {
         Type::None => (
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
             Some(local_path.join("preprocessor_config.json")),
             Some(local_path.join("processor_config.json")),
-            None
+            None,
         ),
         Type::Api(api) => {
             let api_repo = api.repo(Repo::with_revision(
@@ -145,7 +143,6 @@ async fn get_tokenizer(
                 revision.unwrap_or_else(|| "main").to_string(),
             ));
 
             let config_filename = api_repo.get("config.json").await.ok();
-
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
             let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
@@ -176,38 +173,25 @@ async fn get_tokenizer(
                 repo.get("tokenizer_config.json"),
                 repo.get("preprocessor_config.json"),
                 repo.get("processor_config.json"),
-                None
+                None,
             )
         }
     };
 
-    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-    // let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
-    // {
-    //     HubTokenizerConfig::from_file(filename)
-    // } else {
-    //     tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
-    // };
-    // let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
-    //     tracing::warn!("Could not find tokenizer config locally and no API specified");
-    //     HubTokenizerConfig::default()
-    // });
-
     let tokenizer: Tokenizer = {
         use pyo3::prelude::*;
         pyo3::Python::with_gil(|py| -> PyResult<()> {
             py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), false)?;
             Ok(())
         })
         .inspect_err(|err| {
             tracing::error!("Failed to import python tokenizer {err}");
         })
         .or_else(|err| {
             let out = legacy_tokenizer_handle(config_filename.as_ref());
             out.ok_or(err)
         })
         .expect("We cannot load a tokenizer");
         let filename = "out/tokenizer.json";
         if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
             Tokenizer::Rust(tok)
@@ -291,16 +275,13 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
     }
 
     // Create the backend
-    match get_tokenizer(
-        &tokenizer_name,
-        tokenizer_config_path.as_deref(),
-        revision.as_deref(),
-    )
-    .await
-    .expect("Failed to retrieve tokenizer implementation") {
-        Tokenizer::Python { .. } => {
-            Err(TensorRtLlmBackendError::Tokenizer("Failed to retrieve Rust based tokenizer".to_string()))
-        }
+    match get_tokenizer(&tokenizer_name, revision.as_deref())
+        .await
+        .expect("Failed to retrieve tokenizer implementation")
+    {
+        Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer(
+            "Failed to retrieve Rust based tokenizer".to_string(),
+        )),
         Tokenizer::Rust(tokenizer) => {
             info!("Successfully retrieved tokenizer {}", &tokenizer_name);
             let backend = TensorRtLlmBackendV2::new(
@@ -337,9 +318,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
                 max_client_batch_size,
                 usage_stats,
                 payload_limit,
-            ).await?;
+            )
+            .await?;
 
             Ok(())
         }
     }
 }

File: backends/trtllm/tests/test_backend.cpp

@@ -8,13 +8,13 @@
 #include "backend.hpp"
 
 using namespace huggingface::tgi::backends::trtllm;
 
 TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
 {
-    const json config_j = {{"temperature", 0.6}, {"top_p", 0.95}, {"eos_token_id", {1,2,3}}};
+    const json config_j = {{"temperature", 0.6},
+                           {"top_p", 0.95},
+                           {"eos_token_id", {1, 2, 3}}};
     const auto generation_config = generation_config_t(config_j);
 
     REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
@@ -24,8 +24,9 @@ TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
     REQUIRE_FALSE(generation_config.stop_words.empty());
     REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
 
-    for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}}))
-    {
+    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+                                                                                                        {2},
+                                                                                                        {3}})) {
         // Currently we do not support multi-tokens stop words
         REQUIRE(lhs.size() == 1);
         REQUIRE(rhs.size() == 1);
@@ -35,7 +36,7 @@ TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
 
 TEST_CASE("parse generation_config.json default", "[generation_config_t]")
 {
-    const json config_j = {{"eos_token_id", {1,2,3}}};
+    const json config_j = {{"eos_token_id", {1, 2, 3}}};
     const auto generation_config = generation_config_t(config_j);
 
     REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
@@ -44,8 +45,9 @@ TEST_CASE("parse generation_config.json default", "[generation_config_t]")
     REQUIRE_FALSE(generation_config.stop_words.empty());
     REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
 
-    for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}}))
-    {
+    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+                                                                                                        {2},
+                                                                                                        {3}})) {
         // Currently we do not support multi-tokens stop words
         REQUIRE(lhs.size() == 1);
         REQUIRE(rhs.size() == 1);