From 0baa0173a39c3b16dd0f9dd025a0f9a94ab477d9 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Tue, 10 Dec 2024 16:51:22 +0100
Subject: [PATCH] feat(trtllm): expose finish reason to Rust

---
 backends/trtllm/CMakeLists.txt         |  63 +++++++-----
 backends/trtllm/cmake/spdlog.cmake     |   8 +-
 backends/trtllm/cmake/trtllm.cmake     |   6 +-
 backends/trtllm/csrc/backend.cpp       |  19 ++--
 backends/trtllm/csrc/ffi.hpp           | 128 +++++++++++++++----------
 backends/trtllm/src/lib.rs             |  35 +++++++
 backends/trtllm/tests/test_backend.cpp |  18 ++--
 7 files changed, 178 insertions(+), 99 deletions(-)

diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt
index 8ba8df4c..b369c370 100644
--- a/backends/trtllm/CMakeLists.txt
+++ b/backends/trtllm/CMakeLists.txt
@@ -1,11 +1,19 @@
 cmake_minimum_required(VERSION 3.20)

-if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
+if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
     find_program(CCACHE_EXECUTABLE "ccache")
     if (CCACHE_EXECUTABLE)
         message(STATUS "Using ccache")
-        set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
+        set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+        set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+        set(CMAKE_CUDA_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
     endif ()
+else ()
+    find_program(CCACHE_EXECUTABLE ${CMAKE_CXX_COMPILER_LAUNCHER})
+    message(STATUS "Using user specified cmake cxx compiler launcher: ${CMAKE_CXX_COMPILER_LAUNCHER}")
+    set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+    set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+    set(CMAKE_CUDA_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
 endif ()

 if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
@@ -21,28 +29,37 @@ include(CheckCXXCompilerFlag)

 option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
 option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
+option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
 set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
 set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")

 # We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
 find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
+find_package(MPI REQUIRED)

 #### External dependencies ####
 include(cmake/json.cmake)
 include(cmake/spdlog.cmake)
 include(cmake/trtllm.cmake)

-if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
+if (CMAKE_BUILD_TYPE STREQUAL "Debug")
+    set(TGI_TRTLLM_BACKEND_DEBUG ON)
     add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1)
-endif()
+endif ()
+
+if (${TGI_TRTLLM_BACKEND_BUILD_USE_LLD})
+    message(STATUS "Using lld linker")
+    add_link_options("-fuse-ld=lld")
+endif ()

 # This attempt to detect if the compiler can emit warning if it can't apply return value optimization from a function
 check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NVRO)
-if(${COMPILER_SUPPORT_WARNING_ON_NVRO})
-    set(CMAKE_CXX_FLAGS "{CMAKE_CXX_FLAGS} -Wnvro")
-endif()
+if (${COMPILER_SUPPORT_WARNING_ON_NVRO})
+    message(STATUS "Enabling non-NVRO detection")
+    target_compile_options(tgi_trtllm_backend_impl "-Werror -Wnvro")
+endif ()

 # Let's build TRTLLM as part of CMake
 add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
@@ -55,21 +72,20 @@ add_library(tgi_trtllm_backend_impl STATIC csrc/hardware.hpp csrc/backend.hpp csrc/backend.cpp)
 include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
 target_include_directories(tgi_trtllm_backend_impl PRIVATE
         $
-#        $
+        #        $
 )
 target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
 target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)
 target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
-
-if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-    target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
-else()
-    target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
-endif ()
+target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)

 # This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
 install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
-install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} TYPE LIB)
+if (NOT ${TGI_TRTLLM_BACKEND_DEBUG})
+    install(FILES ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+endif ()
+

 #### Unit Tests ####
 if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
@@ -85,18 +101,13 @@ if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
     target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
     target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)
     target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
+    target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)

-    if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-        target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
-    else()
-        target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
-    endif ()
-
-    if(CMAKE_BUILD_TYPE MATCHES "Debug")
-        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
-        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
+    if (CMAKE_BUILD_TYPE MATCHES "Debug")
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fsanitize=undefined")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fsanitize=undefined")
         target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined PUBLIC -fsanitize=address)
-    endif()
+    endif ()

     if(CMAKE_BUILD_TYPE MATCHES "Debug")
         set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")

diff --git a/backends/trtllm/cmake/spdlog.cmake b/backends/trtllm/cmake/spdlog.cmake
index 45e6790a..e7566cd7 100644
--- a/backends/trtllm/cmake/spdlog.cmake
+++ b/backends/trtllm/cmake/spdlog.cmake
@@ -4,14 +4,14 @@ set(SPDLOG_FMT_EXTERNAL OFF)

 # Define the level at which SPDLOG_ compilation level is defined
 if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
 else ()
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
 endif ()

 fetchcontent_declare(
         spdlog
-#        DOWNLOAD_EXTRACT_TIMESTAMP
-        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
+        #        DOWNLOAD_EXTRACT_TIMESTAMP
+        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.0.tar.gz
 )
 fetchcontent_makeavailable(spdlog)

diff --git a/backends/trtllm/cmake/trtllm.cmake b/backends/trtllm/cmake/trtllm.cmake
index 4217892b..d789b1eb 100644
--- a/backends/trtllm/cmake/trtllm.cmake
+++ b/backends/trtllm/cmake/trtllm.cmake
@@ -14,11 +14,13 @@ message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
 set(ENABLE_UCX OFF)
 if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
     set(FAST_BUILD ON)
-    set(NVTX_DISABLE OFF)
+    set(NVTX_DISABLE ON)
+    set(INDEX_RANGE_CHECK ON)
 else ()
     set(FAST_BUILD OFF)
     set(FAST_MATH ON)
-    set(NVTX_DISABLE ON)
+    set(NVTX_DISABLE OFF)
+    set(INDEX_RANGE_CHECK OFF)
 endif ()

 find_package(Python3 REQUIRED Interpreter)

diff --git a/backends/trtllm/csrc/backend.cpp b/backends/trtllm/csrc/backend.cpp
index b50044d8..2151466b 100644
--- a/backends/trtllm/csrc/backend.cpp
+++ b/backends/trtllm/csrc/backend.cpp
@@ -1,7 +1,6 @@
 #include
 #include
-#include

 #include "backend.hpp"
 #include "hardware.hpp"
@@ -17,7 +16,8 @@ namespace huggingface::tgi::backends::trtllm {
         if (world_size > 1) {
             SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
             mode = tle::CommunicationMode::kORCHESTRATOR;
-            orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr, true);
+            orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr,
+                                                                             true);
         } else {
             SPDLOG_INFO("Detected single engine deployment, using leader mode");
         }
@@ -44,21 +44,22 @@ namespace huggingface::tgi::backends::trtllm {
     }

     backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path)
-        : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {}
+            : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {}

     size_t backend_t::num_tokens_ready() const noexcept {
         return executor_.getNumResponsesReady();
     }

     std::expected<request_id_t, backend_error_t>
-    backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t generation_params, const sampling_params_t sampling_params) noexcept {
-        SPDLOG_DEBUG("Submitting {:d} tokens to the executor for scheduling ({}, {})", token_ids.size(), generation_params, sampling_params);
-        return executor_.enqueueRequest(tle::Request {
+    backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t g_params,
+                      const sampling_params_t s_params) noexcept {
+        SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params);
+        return executor_.enqueueRequest(tle::Request{
                 {token_ids.begin(), token_ids.end()},  // Making actual copy of the tokens
-                static_cast<tle::SizeType32>(generation_params.max_new_tokens),
+                static_cast<tle::SizeType32>(g_params.max_new_tokens),
                 true,
-                (tle::SamplingConfig) sampling_params,
-                tle::OutputConfig { /* returnLogProbs= */ true },
+                (tle::SamplingConfig) s_params,
+                tle::OutputConfig{ /* returnLogProbs= */ true},
                 std::nullopt,
                 std::nullopt,
                 std::nullopt,
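Usage sketch (not part of the patch): `backend_t::submit` above enqueues a `tle::Request` with log-probs enabled, and responses are drained via `pull_tokens`. A hypothetical C++ driver for that flow, assuming only the `backend_t` members visible in this diff, the parameter order `{max_new_tokens}` / `{top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed}` used by ffi.hpp, and that `token_id_t` is the int32-compatible token type from backend.hpp; `run_once`, the sampling values, and the poll interval are all illustrative:

    #include <chrono>
    #include <span>
    #include <thread>

    #include "backend.hpp"

    using namespace huggingface::tgi::backends::trtllm;

    // Hypothetical driver: enqueue one prompt, then poll until responses arrive.
    void run_once(backend_t &backend, std::span<const token_id_t> prompt) {
        const auto maybe_request_id =
                backend.submit(prompt, {128}, {1, 0.95f, 1.0f, 0.0f, 0.6f, 2014});
        if (!maybe_request_id.has_value()) return;

        while (backend.num_tokens_ready() == 0)
            std::this_thread::sleep_for(std::chrono::milliseconds(1));

        for (const auto &response : backend.pull_tokens()) {
            // Each tle::Response carries finishReasons, which ffi.hpp below
            // converts into finish_reason_t before crossing into Rust.
            (void) response;
        }
    }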
diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp
index d0342d4b..d14578b8 100644
--- a/backends/trtllm/csrc/ffi.hpp
+++ b/backends/trtllm/csrc/ffi.hpp
@@ -28,20 +28,61 @@ namespace huggingface::tgi::backends::trtllm {
 #include "backends/trtllm/src/lib.rs.h"

+
 namespace huggingface::tgi::backends::trtllm {
     std::once_flag backend_initialized_flag;

+    constexpr finish_reason_t as_finish_reason_t(const tle::FinishReason reason) noexcept {
+        switch (reason) {
+            case tle::FinishReason::kNOT_FINISHED:
+                return finish_reason_t::kNOT_FINISHED;
+            case tle::FinishReason::kSTOP_WORDS:
+                return finish_reason_t::kSTOP_WORDS;
+            case tle::FinishReason::kEND_ID:
+                return finish_reason_t::kEND_ID;
+            case tle::FinishReason::kLENGTH:
+                return finish_reason_t::kLENGTH;
+            default:
+                std::unreachable();
+        }
+    }
+
+    static auto as_generation_step = [](const tle::Response &r) {
+        const auto reqId = r.getRequestId();
+        if (!r.hasError()) [[likely]] {
+            const auto result = r.getResult();
+            return generation_step_t{
+                    reqId,
+                    static_cast<uint32_t>(result.outputTokenIds[0][0]),
+                    result.logProbs.value()[0][0],
+                    result.isFinal,
+                    as_finish_reason_t(result.finishReasons[0]),
+                    false,
+                    std::string()
+            };
+        } else {
+            return generation_step_t{
+                    reqId,
+                    0,
+                    0.0,
+                    true,
+                    finish_reason_t::kNOT_FINISHED,
+                    true,
+                    std::move(r.getErrorMsg())
+            };
+        }
+    };
+
+
     class tensorrt_llm_backend_t {
     private:
         backend_t inner_;

     public:
         tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
-            : inner_(engine_folder, executor_worker_path) {}
+                : inner_(engine_folder, executor_worker_path) {}

-        size_t num_tokens_ready() const noexcept {
-            return inner_.num_tokens_ready();
-        }
+        size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }

         request_id_t submit(
                 rust::Slice<const uint32_t> tokens,
@@ -59,13 +100,13 @@ namespace huggingface::tgi::backends::trtllm {
             // Submit the request to the executor and get back a potential request_id used to track request status
             const auto signed_tokens = std::vector<int32_t>(tokens.begin(), tokens.end());
             const auto maybe_request_id = inner_.submit(
-                signed_tokens,
-                {max_new_tokens},
-                {top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed}
+                    signed_tokens,
+                    {max_new_tokens},
+                    {top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed}
             );

             // If we do have a value, let's return the request_id
-            if(maybe_request_id.has_value()) [[likely]] {
+            if (maybe_request_id.has_value()) [[likely]] {
                 return *maybe_request_id;
             } else {
                 SPDLOG_WARN("[FFI] Failed to submit request to the executor");
@@ -74,61 +115,45 @@ namespace huggingface::tgi::backends::trtllm {
         }

         std::unique_ptr<std::vector<generation_step_t>> pull_tokens() noexcept {
-            if(num_tokens_ready() > 0) [[likely]] {
+            if (num_tokens_ready() > 0) [[likely]] {
                 const auto responses = inner_.pull_tokens();

                 SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());
-                // Transform tle::Response to GenerationStep
-                auto steps = std::make_unique<std::vector<generation_step_t>>();
-                std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
-                    const auto reqId = r.getRequestId();
-                    if (!r.hasError()) [[likely]] {
-                        const auto result = r.getResult();
-                        return generation_step_t{
-                                reqId,
-                                static_cast<uint32_t>(result.outputTokenIds[0][0]),
-                                result.logProbs.value()[0][0],
-                                result.isFinal,
-                                false,
-                                std::string()
-                        };
-                    } else {
-                        return generation_step_t{
-                                reqId,
-                                0,
-                                0.0,
-                                true,
-                                true,
-                                std::move(r.getErrorMsg())
-                        };
-                    }
-                });
-                return steps;
+
+                // Transform tle::Response to generation_step_t
+#ifdef __cpp_lib_ranges_to_container
+                auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to<std::vector>();
+#else
+                auto steps = std::vector<generation_step_t>();
+                steps.reserve(responses.size());
+                std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
+#endif
+                return std::make_unique<std::vector<generation_step_t>>(steps);
             } else {
                 return std::make_unique<std::vector<generation_step_t>>();
             }
         }

-        void cancel(request_id_t requestId) noexcept {
-            SPDLOG_DEBUG("[FFI] cancelling request {:d}", requestId);
-            inner_.cancel(requestId);
+        void cancel(request_id_t request_id) noexcept {
+            SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id);
+            inner_.cancel(request_id);
         }
     };

     void initialize_logging() {
 #ifndef TGI_TRTLLM_BACKEND_DEBUG
         if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
-            std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
-            std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
-                return std::tolower(c);
-            });
+            std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
+            std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
+                return std::tolower(c);
+            });

-            if (log_level == "debug")
-                spdlog::set_level(spdlog::level::debug);
-            else
-                spdlog::set_level(spdlog::level::info);
-        }
+            if (log_level == "debug")
+                spdlog::set_level(spdlog::level::debug);
+            else
+                spdlog::set_level(spdlog::level::info);
+        }
 #else
         spdlog::set_level(spdlog::level::debug);
 #endif
@@ -151,11 +176,14 @@ namespace huggingface::tgi::backends::trtllm {
         }
     }

-    std::unique_ptr<tensorrt_llm_backend_t> create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
+    std::unique_ptr<tensorrt_llm_backend_t>
+    create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
         std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
         return std::make_unique<tensorrt_llm_backend_t>(
-                std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()), std::filesystem::path::format::auto_format),
-                std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()), std::filesystem::path::format::auto_format)
+                std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
+                                      std::filesystem::path::format::auto_format),
+                std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
+                                      std::filesystem::path::format::auto_format)
         );
     }
 }
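Aside (not part of the patch): `pull_tokens` now gates `std::ranges::to` behind the `__cpp_lib_ranges_to_container` feature-test macro, since that facility is C++23-only; older standard libraries take the `std::transform` path instead. A self-contained illustration of the same pattern, using only standard C++ and no names from this patch:

    #include <algorithm>
    #include <iterator>
    #include <ranges>
    #include <vector>

    // Build a vector from a transformed range, preferring ranges::to when the
    // standard library provides it, exactly as pull_tokens() does above.
    std::vector<int> doubled(const std::vector<int> &in) {
    #ifdef __cpp_lib_ranges_to_container
        return in | std::views::transform([](int v) { return 2 * v; })
                  | std::ranges::to<std::vector>();
    #else
        std::vector<int> out;
        out.reserve(in.size());
        std::transform(in.begin(), in.end(), std::back_inserter(out),
                       [](int v) { return 2 * v; });
        return out;
    #endif
    }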
diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index d6acafa1..08507256 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -6,6 +6,26 @@ mod utils;

 #[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
 mod ffi {
+    #[cxx_name = "finish_reason_t"]
+    #[derive(Debug, Clone, Copy)]
+    pub enum FinishReason {
+        /// The request is not finished.
+        #[cxx_name = "kNOT_FINISHED"]
+        NotFinished = 0u8,
+
+        /// The request finished because the end id was generated.
+        #[cxx_name = "kEND_ID"]
+        EndTokenId = 1u8,
+
+        /// The request finished because a stop word was generated.
+        #[cxx_name = "kSTOP_WORDS"]
+        StopWords = 2u8,
+
+        /// The request finished because the maximum number of tokens was reached.
+        #[cxx_name = "kLENGTH"]
+        MaxLength = 3u8,
+    }
+
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
     #[cxx_name = "generation_step_t"]
@@ -15,6 +35,7 @@ mod ffi {
         token_id: u32,
         log_prob: f32,
         is_final: bool,
+        finish_reason: FinishReason,
         has_error: bool,
         error_msg: String,
     }
@@ -66,3 +87,17 @@ mod ffi {
         fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
     }
 }
+
+use ffi::FinishReason;
+use text_generation_router::FinishReason as InferFinishReason;
+
+impl From<FinishReason> for InferFinishReason {
+    fn from(reason: FinishReason) -> Self {
+        match reason {
+            FinishReason::StopWords => InferFinishReason::StopSequence,
+            FinishReason::MaxLength => InferFinishReason::Length,
+            FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken,
+            _ => panic!("Cannot convert {reason:?} to text_generation_router::FinishReason"),
+        }
+    }
+}

diff --git a/backends/trtllm/tests/test_backend.cpp b/backends/trtllm/tests/test_backend.cpp
index 14d92b75..f44cc03f 100644
--- a/backends/trtllm/tests/test_backend.cpp
+++ b/backends/trtllm/tests/test_backend.cpp
@@ -8,13 +8,13 @@

 #include "backend.hpp"

-
-
 using namespace huggingface::tgi::backends::trtllm;

 TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
 {
-    const json config_j = {{"temperature", 0.6}, {"top_p", 0.95}, {"eos_token_id", {1,2,3}}};
+    const json config_j = {{"temperature", 0.6},
+                           {"top_p", 0.95},
+                           {"eos_token_id", {1, 2, 3}}};
     const auto generation_config = generation_config_t(config_j);

     REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
@@ -24,8 +24,9 @@ TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
     REQUIRE_FALSE(generation_config.stop_words.empty());
     REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());

-    for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}}))
-    {
+    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+                                                                                                        {2},
+                                                                                                        {3}})) {
         // Currently we do not support multi-tokens stop words
         REQUIRE(lhs.size() == 1);
         REQUIRE(rhs.size() == 1);
@@ -35,7 +36,7 @@ TEST_CASE("parse generation_config.json default", "[generation_config_t]")
 {
-    const json config_j = {{"eos_token_id", {1,2,3}}};
+    const json config_j = {{"eos_token_id", {1, 2, 3}}};
     const auto generation_config = generation_config_t(config_j);

     REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
@@ -44,8 +45,9 @@ TEST_CASE("parse generation_config.json default", "[generation_config_t]")
     REQUIRE_FALSE(generation_config.stop_words.empty());
     REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());

-    for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}}))
-    {
+    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+                                                                                                        {2},
+                                                                                                        {3}})) {
         // Currently we do not support multi-tokens stop words
         REQUIRE(lhs.size() == 1);
         REQUIRE(rhs.size() == 1);
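Aside (not part of the patch): the new `as_finish_reason_t` mapping in ffi.hpp is a natural candidate for coverage in the same Catch2 suite. A sketch, assuming `as_finish_reason_t` and the cxx-generated `finish_reason_t` (whose `==` operator cxx provides for shared enums) are reachable from the test target; the test name and the `ffi.hpp` include path are illustrative:

    #include <catch2/catch_all.hpp>

    #include "ffi.hpp"

    using namespace huggingface::tgi::backends::trtllm;

    // Verify the executor-to-FFI finish reason mapping is 1:1 for every
    // tle::FinishReason value handled by as_finish_reason_t above.
    TEST_CASE("executor finish reasons map onto finish_reason_t", "[ffi]")
    {
        REQUIRE(as_finish_reason_t(tle::FinishReason::kNOT_FINISHED) == finish_reason_t::kNOT_FINISHED);
        REQUIRE(as_finish_reason_t(tle::FinishReason::kEND_ID) == finish_reason_t::kEND_ID);
        REQUIRE(as_finish_reason_t(tle::FinishReason::kSTOP_WORDS) == finish_reason_t::kSTOP_WORDS);
        REQUIRE(as_finish_reason_t(tle::FinishReason::kLENGTH) == finish_reason_t::kLENGTH);
    }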