diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt index 9c1f3436..faf4be57 100644 --- a/backends/trtllm/CMakeLists.txt +++ b/backends/trtllm/CMakeLists.txt @@ -1,11 +1,19 @@ cmake_minimum_required(VERSION 3.20) -if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug") +if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER) find_program(CCACHE_EXECUTABLE "ccache") if (CCACHE_EXECUTABLE) message(STATUS "Using ccache") - set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE) + set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE}) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE}) + set(CMAKE_CUDA_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE}) endif () +else () + find_program(CCACHE_EXECUTABLE ${CMAKE_CXX_COMPILER_LAUNCHER}) + message(STATUS "Using user specified cmake cxx compiler launcher: ${CMAKE_CXX_COMPILER_LAUNCHER}") + set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE}) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE}) + set(CMAKE_CUDA_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE}) endif () if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") @@ -21,28 +29,37 @@ include(CheckCXXCompilerFlag) option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF) option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF) +option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF) set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support") -set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located") +set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path rgo where TensorRT libraries and headers are located") set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located") set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located") # We are using nvidia-ml to query at runtime device information to enable some architecture-specific features find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml) +find_package(MPI REQUIRED) #### External dependencies #### include(cmake/json.cmake) include(cmake/spdlog.cmake) include(cmake/trtllm.cmake) -if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") +if (CMAKE_BUILD_TYPE STREQUAL "Debug") + set(TGI_TRTLLM_BACKEND_DEBUG ON) add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1) -endif() +endif () + +if (${TGI_TRTLLM_BACKEND_BUILD_USE_LLD}) + message(STATUS "Using lld linker") + add_link_options("-fuse-ld=lld") +endif () # This attempt to detect if the compiler can emit warning if it can't apply return value optimization from a function check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NVRO) -if(${COMPILER_SUPPORT_WARNING_ON_NVRO}) - set(CMAKE_CXX_FLAGS "{CMAKE_CXX_FLAGS} -Wnvro") -endif() +if (${COMPILER_SUPPORT_WARNING_ON_NVRO}) + message(STATUS "Enabling non-NVRO detection") + target_compile_options(tgi_trtllm_backend_impl "-Werror -Wnvro") +endif () # Let's build TRTLLM as part of CMake add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..") @@ -55,21 +72,20 @@ add_library(tgi_trtllm_backend_impl STATIC csrc/hardware.hpp csrc/backend.hpp cs include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) target_include_directories(tgi_trtllm_backend_impl PRIVATE $ -# $ + # $ ) target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include") target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml) target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog) - -if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") - target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm) -else() - target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm) -endif () +target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper) # This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker) -install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB) +install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} TYPE LIB) +if (NOT ${TGI_TRTLLM_BACKEND_DEBUG}) + install(FILES ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB) +endif () + #### Unit Tests #### if (${TGI_TRTLLM_BACKEND_BUILD_TESTS}) @@ -85,18 +101,13 @@ if (${TGI_TRTLLM_BACKEND_BUILD_TESTS}) target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/") target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml) target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl) + target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper) - if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") - target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm) - else() - target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm) - endif () - - if(CMAKE_BUILD_TYPE MATCHES "Debug") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fsanitize=undefined -fsanitize=address") + if (CMAKE_BUILD_TYPE MATCHES "Debug") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fsanitize=undefined") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fsanitize=undefined") target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined PUBLIC -fsanitize=address) - endif() + endif () list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) include(CTest) diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs index 0a0f6e6b..ee4e49d1 100644 --- a/backends/trtllm/build.rs +++ b/backends/trtllm/build.rs @@ -15,9 +15,8 @@ const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR"); // Dependencies const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"]; const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"]; -const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [ +const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 4] = [ ("dylib", "tensorrt_llm"), - ("static", "tensorrt_llm_executor_static"), ("dylib", "tensorrt_llm_nvrtc_wrapper"), ("dylib", "nvinfer_plugin_tensorrt_llm"), ("dylib", "decoder_attention"), @@ -32,6 +31,58 @@ macro_rules! probe { }; } +fn get_compiler_flag( + switch: bool, + true_case: &'static str, + false_case: &'static str, +) -> &'static str { + match switch { + true => true_case, + false => false_case, + } +} + +#[cfg(target_arch = "x86_64")] +fn get_system_install_path(install_path: &PathBuf) -> PathBuf { + install_path.join("lib64") +} + +#[cfg(not(target_arch = "x86_64"))] +fn get_system_install_path(install_path: &PathBuf) -> PathBuf { + install_path.join("lib") +} + +fn get_library_architecture() -> &'static str { + let os = env::var("CARGO_CFG_TARGET_OS").unwrap(); + let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap(); + let env = env::var("CARGO_CFG_TARGET_ENV").unwrap(); + + match os.as_str() { + "linux" => { + if env != "gnu" { + panic!("unsupported linux ABI {env}, only 'gnu' is supported") + } + + match arch.as_str() { + "x86_64" => "x86_64-linux-gnu", + "aarch64" => "aarch64-linux-gnu", + _ => panic!("unsupported linux architecture {arch}"), + } + } + "windows" => { + if env != "msvc" { + panic!("unsupported windows ABI {env}, only 'msvc' is supported") + } + + match arch.as_str() { + "x86_64" => "x86_64-windows-msvc", + _ => panic!("unsupported windows architecture {arch}"), + } + } + _ => panic!("unsupported OS {os}"), + } +} + fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) { // Build the backend implementation through CMake let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi"); @@ -44,7 +95,8 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf } let mut config = cmake::Config::new("."); - config.uses_cxx11() + config + .uses_cxx11() .generator("Ninja") .profile(match is_debug { true => "Debug", @@ -53,16 +105,28 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf .env("OPT_LEVEL", opt_level) .define("CMAKE_INSTALL_PREFIX", &install_path) .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc") - .define("Python3_ROOT_DIR", "../venv") + .define("CMAKE_LIBRARY_ARCHITECTURE", get_library_architecture()) .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list) + .define( + "TGI_TRTLLM_BACKEND_DEBUG", + get_compiler_flag(is_debug, "ON", "OFF"), + ) .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path); - // Allow to override which Python to use ... - if let Some(python3) = option_env!("Python3_EXECUTABLE") { - config.define("Python3_EXECUTABLE", python3); - } + if let Some(nvcc_host_compiler) = option_env!("CMAKE_CUDA_HOST_COMPILER") { + config.define("CMAKE_CUDA_HOST_COMPILER", nvcc_host_compiler); + } - config.build(); + if let Some(cxx_compiler_launcher) = option_env!("CMAKE_CXX_COMPILER_LAUNCHER") { + config.define("CMAKE_CXX_COMPILER_LAUNCHER", cxx_compiler_launcher); + } + + // Allow to override which Python to use ... + if let Some(python3) = option_env!("Python3_EXECUTABLE") { + config.define("Python3_EXECUTABLE", python3); + } + + config.build(); // Additional transitive CMake dependencies let deps_folder = out_dir.join("build").join("_deps"); @@ -77,7 +141,8 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf } // Emit linkage information from the artifacts we just built - let install_lib_path = install_path.join("lib"); + + let install_lib_path = get_system_install_path(&install_path); println!( r"cargo:warning=Adding link search path: {}", @@ -89,11 +154,6 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf } fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) { - let ndebug = match is_debug { - true => "1", - false => "0", - }; - CFG.include_prefix = "backends/trtllm"; cxx_build::bridge("src/lib.rs") .static_flag(true) @@ -105,7 +165,10 @@ fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) { .include("/usr/local/tensorrt/include") .include("csrc/") .file("csrc/ffi.hpp") - .define("TGI_TRTLLM_BACKEND_DEBUG", ndebug) + .define( + "TGI_TRTLLM_BACKEND_DEBUG", + get_compiler_flag(is_debug, "ON", "OFF"), + ) .compile("tgi_trtllm_backend"); println!("cargo:rerun-if-changed=CMakeLists.txt"); diff --git a/backends/trtllm/cmake/spdlog.cmake b/backends/trtllm/cmake/spdlog.cmake index 45e6790a..e7566cd7 100644 --- a/backends/trtllm/cmake/spdlog.cmake +++ b/backends/trtllm/cmake/spdlog.cmake @@ -4,14 +4,14 @@ set(SPDLOG_FMT_EXTERNAL OFF) # Define the level at which SPDLOG_ compilation level is defined if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") - add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG) + add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE) else () - add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO) + add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG) endif () fetchcontent_declare( spdlog -# DOWNLOAD_EXTRACT_TIMESTAMP - URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz + # DOWNLOAD_EXTRACT_TIMESTAMP + URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.0.tar.gz ) fetchcontent_makeavailable(spdlog) diff --git a/backends/trtllm/cmake/trtllm.cmake b/backends/trtllm/cmake/trtllm.cmake index 4217892b..d789b1eb 100644 --- a/backends/trtllm/cmake/trtllm.cmake +++ b/backends/trtllm/cmake/trtllm.cmake @@ -14,11 +14,13 @@ message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}") set(ENABLE_UCX OFF) if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") set(FAST_BUILD ON) - set(NVTX_DISABLE OFF) + set(NVTX_DISABLE ON) + set(INDEX_RANGE_CHECK ON) else () set(FAST_BUILD OFF) set(FAST_MATH ON) - set(NVTX_DISABLE ON) + set(NVTX_DISABLE OFF) + set(INDEX_RANGE_CHECK OFF) endif () find_package(Python3 REQUIRED Interpreter) diff --git a/backends/trtllm/csrc/backend.cpp b/backends/trtllm/csrc/backend.cpp index b50044d8..2151466b 100644 --- a/backends/trtllm/csrc/backend.cpp +++ b/backends/trtllm/csrc/backend.cpp @@ -1,7 +1,6 @@ #include #include -#include #include "backend.hpp" #include "hardware.hpp" @@ -17,7 +16,8 @@ namespace huggingface::tgi::backends::trtllm { if (world_size > 1) { SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode"); mode = tle::CommunicationMode::kORCHESTRATOR; - orchestratorConfig = std::make_optional(true, executor_worker_path_, nullptr, true); + orchestratorConfig = std::make_optional(true, executor_worker_path_, nullptr, + true); } else { SPDLOG_INFO("Detected single engine deployment, using leader mode"); } @@ -44,21 +44,22 @@ namespace huggingface::tgi::backends::trtllm { } backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path) - : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {} + : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {} size_t backend_t::num_tokens_ready() const noexcept { return executor_.getNumResponsesReady(); } std::expected - backend_t::submit(std::span token_ids, const generation_params_t generation_params, const sampling_params_t sampling_params) noexcept { - SPDLOG_DEBUG("Submitting {:d} tokens to the executor for scheduling ({}, {})", token_ids.size(), generation_params, sampling_params); - return executor_.enqueueRequest(tle::Request { + backend_t::submit(std::span token_ids, const generation_params_t g_params, + const sampling_params_t s_params) noexcept { + SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params); + return executor_.enqueueRequest(tle::Request{ {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens - static_cast(generation_params.max_new_tokens), + static_cast(g_params.max_new_tokens), true, - (tle::SamplingConfig) sampling_params, - tle::OutputConfig { /* returnLogProbs= */ true }, + (tle::SamplingConfig) s_params, + tle::OutputConfig{ /* returnLogProbs= */ true}, std::nullopt, std::nullopt, std::nullopt, diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp index de2333af..c57f8d18 100644 --- a/backends/trtllm/csrc/ffi.hpp +++ b/backends/trtllm/csrc/ffi.hpp @@ -28,20 +28,61 @@ namespace huggingface::tgi::backends::trtllm { #include "backends/trtllm/src/lib.rs.h" + namespace huggingface::tgi::backends::trtllm { std::once_flag backend_initialized_flag; + constexpr finish_reason_t as_finish_reason_t(const tle::FinishReason reason) noexcept { + switch (reason) { + case tle::FinishReason::kNOT_FINISHED: + return finish_reason_t::kNOT_FINISHED; + case tle::FinishReason::kSTOP_WORDS: + return finish_reason_t::kSTOP_WORDS; + case tle::FinishReason::kEND_ID: + return finish_reason_t::kEND_ID; + case tle::FinishReason::kLENGTH: + return finish_reason_t::kLENGTH; + default: + std::unreachable(); + } + } + + static auto as_generation_step = [](const tle::Response &r) { + const auto reqId = r.getRequestId(); + if (!r.hasError()) [[likely]] { + const auto result = r.getResult(); + return generation_step_t{ + reqId, + static_cast(result.outputTokenIds[0][0]), + result.logProbs.value()[0][0], + result.isFinal, + as_finish_reason_t(result.finishReasons[0]), + false, + std::string() + }; + } else { + return generation_step_t{ + reqId, + 0, + 0.0, + true, + finish_reason_t::kNOT_FINISHED, + true, + std::move(r.getErrorMsg()) + }; + } + }; + + class tensorrt_llm_backend_t { private: backend_t inner_; public: tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path) - : inner_(engine_folder, executor_worker_path) {} + : inner_(engine_folder, executor_worker_path) {} - size_t num_tokens_ready() const noexcept { - return inner_.num_tokens_ready(); - } + size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); } request_id_t submit( rust::Slice tokens, @@ -59,13 +100,13 @@ namespace huggingface::tgi::backends::trtllm { // Submit the request to the executor and get back a potential request_id used to track request status const auto signed_tokens = std::vector(tokens.begin(), tokens.end()); const auto maybe_request_id = inner_.submit( - signed_tokens, - {max_new_tokens}, - {top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed} + signed_tokens, + {max_new_tokens}, + {top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed} ); // If we do have a value, let's return the request_id - if(maybe_request_id.has_value()) [[likely]] { + if (maybe_request_id.has_value()) [[likely]] { return *maybe_request_id; } else { SPDLOG_WARN("[FFI] Failed to submit request to the executor"); @@ -74,61 +115,45 @@ namespace huggingface::tgi::backends::trtllm { } std::unique_ptr> pull_tokens() noexcept { - if(num_tokens_ready() > 0) [[likely]] { + if (num_tokens_ready() > 0) [[likely]] { const auto responses = inner_.pull_tokens(); SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size()); - // Transform tle::Response to GenerationStep - auto steps = std::make_unique>(); - std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) { - const auto reqId = r.getRequestId(); - if (!r.hasError()) [[likely]] { - const auto result = r.getResult(); - return generation_step_t{ - reqId, - static_cast(result.outputTokenIds[0][0]), - result.logProbs.value()[0][0], - result.isFinal, - false, - std::string() - }; - } else { - return generation_step_t{ - reqId, - 0, - 0.0, - true, - true, - std::move(r.getErrorMsg()) - }; - } - }); - return steps; + + // Transform tle::Response to generation_step_t +#ifdef __cpp_lib_ranges_to_container + auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to(); +#else + auto steps = std::vector(); + steps.reserve(responses.size()); + std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step); +#endif + return std::make_unique>(steps); } else { return std::make_unique>(); } } - void cancel(request_id_t requestId) noexcept { - SPDLOG_DEBUG("[FFI] cancelling request {:d}", requestId); - inner_.cancel(requestId); + void cancel(request_id_t request_id) noexcept { + SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id); + inner_.cancel(request_id); } }; void initialize_logging() { #ifndef TGI_TRTLLM_BACKEND_DEBUG if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) { - std::string log_level(TRTLLM_LOG_LEVEL_CSTR); - std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) { - return std::tolower(c); - }); + std::string log_level(TRTLLM_LOG_LEVEL_CSTR); + std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) { + return std::tolower(c); + }); - if (log_level == "debug") - spdlog::set_level(spdlog::level::debug); - else - spdlog::set_level(spdlog::level::info); - } + if (log_level == "debug") + spdlog::set_level(spdlog::level::debug); + else + spdlog::set_level(spdlog::level::info); + } #else spdlog::set_level(spdlog::level::debug); #endif @@ -151,11 +176,14 @@ namespace huggingface::tgi::backends::trtllm { } } - std::unique_ptr create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) { + std::unique_ptr + create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) { std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend); return std::make_unique( - std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()), std::filesystem::path::format::auto_format), - std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()), std::filesystem::path::format::auto_format) + std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()), + std::filesystem::path::format::auto_format), + std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()), + std::filesystem::path::format::auto_format) ); } } diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index d6acafa1..08507256 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -6,6 +6,26 @@ mod utils; #[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")] mod ffi { + #[cxx_name = "finish_reason_t"] + #[derive(Debug, Clone, Copy)] + pub enum FinishReason { + /// The request is not finished. + #[cxx_name = "kNOT_FINISHED"] + NotFinished = 0u8, + + /// The request finished because the end id was generated. + #[cxx_name = "kEND_ID"] + EndTokenId = 1u8, + + /// The request finished because a stop word was generated. + #[cxx_name = "kSTOP_WORDS"] + StopWords = 2u8, + + /// The request finished because the maximum number of tokens was reached. + #[cxx_name = "kLENGTH"] + MaxLength = 3u8, + } + /// Struct used as shared type between rust and C++ to represent the result /// of a single decoding iteration #[cxx_name = "generation_step_t"] @@ -15,6 +35,7 @@ mod ffi { token_id: u32, log_prob: f32, is_final: bool, + finish_reason: FinishReason, has_error: bool, error_msg: String, } @@ -66,3 +87,17 @@ mod ffi { fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64); } } + +use ffi::FinishReason; +use text_generation_router::FinishReason as InferFinishReason; + +impl From for InferFinishReason { + fn from(reason: FinishReason) -> Self { + match reason { + FinishReason::StopWords => InferFinishReason::StopSequence, + FinishReason::MaxLength => InferFinishReason::Length, + FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken, + _ => panic!("Cannot convert {reason:?} to text_generation_router::FinishReason"), + } + } +} diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index 3addd95f..cdf2c8f8 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -10,7 +10,7 @@ use tokio::sync::TryAcquireError; use tokio::task::spawn_blocking; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{debug, error, warn}; +use tracing::{debug, error, info, warn}; use text_generation_router::infer::InferError::{GenerationError, ValidationError}; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; @@ -18,10 +18,12 @@ use text_generation_router::validation::ValidationError::{ EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality, }; use text_generation_router::validation::{Chunk, ValidGenerateRequest}; -use text_generation_router::{FinishReason, Token}; +use text_generation_router::Token; use crate::errors::TensorRtLlmBackendError; -use crate::ffi::{create_backend_from_engine_folder, GenerationStep, TensorRtLlmBackendImpl}; +use crate::ffi::{ + create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl, +}; use crate::utils::first_line; type InferResult = Result; @@ -40,6 +42,7 @@ struct DecodedToken { id: u32, log_prob: f32, is_final: bool, + finish_reason: FinishReason, } impl<'step> TryFrom<&'step GenerationStep> for DecodedToken { @@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken { id: step.token_id, log_prob: step.log_prob, is_final: step.is_final, + finish_reason: step.finish_reason, }) } else { Err(GenerationError(step.error_msg.clone())) @@ -133,13 +137,16 @@ fn executor_status_looper( Ok(decoded_token) => { post_process_decoded_token(&tokenizer, ctx, decoded_token) } - Err(err) => Err(err) + Err(err) => Err(err), }; // Attempt to send back the response to the client if let Err(_) = ctx.streamer.send(response) { // Client has dropped, remove from tracked requests - debug!("Client dropped - removing request {} from tracked requests", step.request_id); + info!( + "Client dropped - removing request {} from tracked requests", + step.request_id + ); backend.as_mut().cancel(step.request_id); let _ = in_flights.remove(&step.request_id); } @@ -160,11 +167,14 @@ fn executor_status_looper( } } -fn post_process_decoded_token(tokenizer: &Tokenizer, ctx: &mut GenerationContext, decoded_token: DecodedToken) -> InferResult { +fn post_process_decoded_token( + tokenizer: &Tokenizer, + ctx: &mut GenerationContext, + decoded_token: DecodedToken, +) -> InferResult { match tokenizer.decode(&[decoded_token.id], false) { Ok(text) => { - let is_special = - tokenizer.get_added_vocabulary().is_special_token(&text); + let is_special = tokenizer.get_added_vocabulary().is_special_token(&text); let token = Token { id: decoded_token.id, text, @@ -186,7 +196,7 @@ fn post_process_decoded_token(tokenizer: &Tokenizer, ctx: &mut GenerationContext let generated_text = GeneratedText { text: text.unwrap(), generated_tokens: ctx.tokens.len() as u32, - finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason + finish_reason: decoded_token.finish_reason.into(), seed: None, }; @@ -248,7 +258,6 @@ unsafe impl Send for TensorRtLlmBackendImpl {} pub struct TensorRtLlmBackendV2(UnboundedSender); - impl TensorRtLlmBackendV2 { pub fn new + Send, PP: AsRef + Send>( tokenizer: Tokenizer, @@ -268,12 +277,7 @@ impl TensorRtLlmBackendV2 { // Executor looper is responsible for scheduling and pulling requests state at regular interval spawn_blocking(move || { - executor_status_looper( - max_inflight_requests, - tokenizer, - backend, - executor_receiver, - ) + executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver) }); Ok(TensorRtLlmBackendV2(executor_sender)) diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index 9c76bafa..cef225be 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -7,9 +7,11 @@ use tracing::info; use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; use text_generation_backends_trtllm::TensorRtLlmBackendV2; +use text_generation_router::server::{ + get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer, +}; use text_generation_router::usage_stats::UsageStatsLevel; -use text_generation_router::{server, HubTokenizerConfig, Tokenizer}; -use text_generation_router::server::{get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer}; +use text_generation_router::{server, Tokenizer}; /// App Configuration #[derive(Parser, Debug)] @@ -65,11 +67,7 @@ struct Args { payload_limit: usize, } -async fn get_tokenizer( - tokenizer_name: &str, - tokenizer_config_path: Option<&str>, - revision: Option<&str>, -) -> Option { +async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option { // Parse Huggingface hub token let authorization_token = std::env::var("HF_TOKEN") .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) @@ -129,14 +127,14 @@ async fn get_tokenizer( _tokenizer_config_filename, _preprocessor_config_filename, _processor_config_filename, - _model_info + _model_info, ) = match api { Type::None => ( Some(local_path.join("config.json")), Some(local_path.join("tokenizer_config.json")), Some(local_path.join("preprocessor_config.json")), Some(local_path.join("processor_config.json")), - None + None, ), Type::Api(api) => { let api_repo = api.repo(Repo::with_revision( @@ -145,7 +143,6 @@ async fn get_tokenizer( revision.unwrap_or_else(|| "main").to_string(), )); - let config_filename = api_repo.get("config.json").await.ok(); let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok(); @@ -176,38 +173,25 @@ async fn get_tokenizer( repo.get("tokenizer_config.json"), repo.get("preprocessor_config.json"), repo.get("processor_config.json"), - None + None, ) } }; - // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. - // let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path - // { - // HubTokenizerConfig::from_file(filename) - // } else { - // tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) - // }; - - // let tokenizer_config = tokenizer_config.unwrap_or_else(|| { - // tracing::warn!("Could not find tokenizer config locally and no API specified"); - // HubTokenizerConfig::default() - // }); - let tokenizer: Tokenizer = { use pyo3::prelude::*; pyo3::Python::with_gil(|py| -> PyResult<()> { py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), false)?; Ok(()) }) - .inspect_err(|err| { - tracing::error!("Failed to import python tokenizer {err}"); - }) - .or_else(|err| { - let out = legacy_tokenizer_handle(config_filename.as_ref()); - out.ok_or(err) - }) - .expect("We cannot load a tokenizer"); + .inspect_err(|err| { + tracing::error!("Failed to import python tokenizer {err}"); + }) + .or_else(|err| { + let out = legacy_tokenizer_handle(config_filename.as_ref()); + out.ok_or(err) + }) + .expect("We cannot load a tokenizer"); let filename = "out/tokenizer.json"; if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) { Tokenizer::Rust(tok) @@ -291,16 +275,13 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { } // Create the backend - match get_tokenizer( - &tokenizer_name, - tokenizer_config_path.as_deref(), - revision.as_deref(), - ) - .await - .expect("Failed to retrieve tokenizer implementation") { - Tokenizer::Python { .. } => { - Err(TensorRtLlmBackendError::Tokenizer("Failed to retrieve Rust based tokenizer".to_string())) - } + match get_tokenizer(&tokenizer_name, revision.as_deref()) + .await + .expect("Failed to retrieve tokenizer implementation") + { + Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer( + "Failed to retrieve Rust based tokenizer".to_string(), + )), Tokenizer::Rust(tokenizer) => { info!("Successfully retrieved tokenizer {}", &tokenizer_name); let backend = TensorRtLlmBackendV2::new( @@ -337,9 +318,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { max_client_batch_size, usage_stats, payload_limit, - ).await?; + ) + .await?; Ok(()) } } - } diff --git a/backends/trtllm/tests/test_backend.cpp b/backends/trtllm/tests/test_backend.cpp index ae097405..239058be 100644 --- a/backends/trtllm/tests/test_backend.cpp +++ b/backends/trtllm/tests/test_backend.cpp @@ -8,13 +8,13 @@ #include "backend.hpp" - - using namespace huggingface::tgi::backends::trtllm; TEST_CASE("parse generation_config.json all set", "[generation_config_t]") { - const json config_j = {{"temperature", 0.6}, {"top_p", 0.95}, {"eos_token_id", {1,2,3}}}; + const json config_j = {{"temperature", 0.6}, + {"top_p", 0.95}, + {"eos_token_id", {1, 2, 3}}}; const auto generation_config = generation_config_t(config_j); REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6)); @@ -24,8 +24,9 @@ TEST_CASE("parse generation_config.json all set", "[generation_config_t]") REQUIRE_FALSE(generation_config.stop_words.empty()); REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size()); - for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list>{{1}, {2}, {3}})) - { + for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list>{{1}, + {2}, + {3}})) { // Currently we do not support multi-tokens stop words REQUIRE(lhs.size() == 1); REQUIRE(rhs.size() == 1); @@ -35,7 +36,7 @@ TEST_CASE("parse generation_config.json all set", "[generation_config_t]") TEST_CASE("parse generation_config.json default", "[generation_config_t]") { - const json config_j = {{"eos_token_id", {1,2,3}}}; + const json config_j = {{"eos_token_id", {1, 2, 3}}}; const auto generation_config = generation_config_t(config_j); REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6)); @@ -44,8 +45,9 @@ TEST_CASE("parse generation_config.json default", "[generation_config_t]") REQUIRE_FALSE(generation_config.stop_words.empty()); REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size()); - for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list>{{1}, {2}, {3}})) - { + for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list>{{1}, + {2}, + {3}})) { // Currently we do not support multi-tokens stop words REQUIRE(lhs.size() == 1); REQUIRE(rhs.size() == 1);