mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

feat(trtllm): expose finish reason to Rust

commit 0baa0173a3 (parent f729f2c59b)
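Summary of the change, as visible in the diff below: a finish_reason_t enum is shared across the cxx bridge, populated from tle::FinishReason when responses are pulled from the executor, and converted to text_generation_router::FinishReason on the Rust side; the response-to-step transformation moves into a reusable as_generation_step helper driven by a C++23 ranges pipeline where available; the build scripts gain broader ccache handling, an optional lld linker, MPI discovery, spdlog 1.15.0, and Debug-specific TensorRT-LLM settings.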
--- a/backends/trtllm/CMakeLists.txt
+++ b/backends/trtllm/CMakeLists.txt
@@ -1,11 +1,19 @@
 cmake_minimum_required(VERSION 3.20)
 
-if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
+if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
     find_program(CCACHE_EXECUTABLE "ccache")
     if (CCACHE_EXECUTABLE)
         message(STATUS "Using ccache")
-        set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
+        set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+        set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+        set(CMAKE_CUDA_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
     endif ()
+else ()
+    find_program(CCACHE_EXECUTABLE ${CMAKE_CXX_COMPILER_LAUNCHER})
+    message(STATUS "Using user specified cmake cxx compiler launcher: ${CMAKE_CXX_COMPILER_LAUNCHER}")
+    set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+    set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
+    set(CMAKE_CUDA_COMPILER_LAUNCHER ${CCACHE_EXECUTABLE})
 endif ()
 
 if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
@@ -21,28 +29,37 @@ include(CheckCXXCompilerFlag)
 
 option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
 option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
+option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
 set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
 set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
 
 # We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
 find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
+find_package(MPI REQUIRED)
 
 #### External dependencies ####
 include(cmake/json.cmake)
 include(cmake/spdlog.cmake)
 include(cmake/trtllm.cmake)
 
-if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
+if (CMAKE_BUILD_TYPE STREQUAL "Debug")
+    set(TGI_TRTLLM_BACKEND_DEBUG ON)
     add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1)
-endif()
+endif ()
 
+if (${TGI_TRTLLM_BACKEND_BUILD_USE_LLD})
+    message(STATUS "Using lld linker")
+    add_link_options("-fuse-ld=lld")
+endif ()
+
 # This attempts to detect whether the compiler can emit a warning when it cannot apply return value optimization for a function
 check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NVRO)
-if(${COMPILER_SUPPORT_WARNING_ON_NVRO})
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wnrvo")
-endif()
+if (${COMPILER_SUPPORT_WARNING_ON_NVRO})
+    message(STATUS "Enabling non-NRVO detection")
+    target_compile_options(tgi_trtllm_backend_impl PRIVATE -Werror -Wnrvo)
+endif ()
 
 # Let's build TRTLLM as part of CMake
 add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
@@ -55,21 +72,20 @@ add_library(tgi_trtllm_backend_impl STATIC csrc/hardware.hpp csrc/backend.hpp cs
 include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
 target_include_directories(tgi_trtllm_backend_impl PRIVATE
         $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/csrc>
         # $<INSTALL_INTERFACE:csrc>
 )
 target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
 target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)
 target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
-if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-    target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
-else()
-    target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
-endif ()
+target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
 
 # This installs all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make it easy to link / find them back
 install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
-install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} TYPE LIB)
+if (NOT ${TGI_TRTLLM_BACKEND_DEBUG})
+    install(FILES ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+endif ()
 
 
 #### Unit Tests ####
 if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
@@ -85,18 +101,13 @@ if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
     target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
     target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)
     target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
+    target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
 
-    if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-        target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
-    else()
-        target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
-    endif ()
-
-    if(CMAKE_BUILD_TYPE MATCHES "Debug")
-        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
-        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
+    if (CMAKE_BUILD_TYPE MATCHES "Debug")
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fsanitize=undefined")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fsanitize=undefined")
         target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined PUBLIC -fsanitize=address)
-    endif()
+    endif ()
 
     if(CMAKE_BUILD_TYPE MATCHES "Debug")
         set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
--- a/backends/trtllm/cmake/spdlog.cmake
+++ b/backends/trtllm/cmake/spdlog.cmake
@@ -4,14 +4,14 @@ set(SPDLOG_FMT_EXTERNAL OFF)
 
 # Define the level at which SPDLOG_ compilation level is defined
 if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_DEBUG)
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE)
 else ()
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_INFO)
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_DEBUG)
 endif ()
 
 fetchcontent_declare(
         spdlog
         # DOWNLOAD_EXTRACT_TIMESTAMP
-        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
+        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.0.tar.gz
 )
 fetchcontent_makeavailable(spdlog)
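Note: SPDLOG_ACTIVE_LEVEL is a compile-time gate: SPDLOG_* macro calls below the active level are removed from the binary entirely, and spdlog's runtime level only filters whatever survived compilation. A minimal sketch of that behavior (not part of the commit; in the backend the macro value is injected by the CMake branch above):

#include <spdlog/spdlog.h>

int main() {
    // Runtime filter: only affects call sites that were compiled in.
    spdlog::set_level(spdlog::level::trace);
    SPDLOG_TRACE("compiled out unless SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_TRACE (the new Debug setting)");
    SPDLOG_DEBUG("compiled in under both of the new settings above");
    SPDLOG_INFO("compiled in under all of the levels involved here");
    return 0;
}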
--- a/backends/trtllm/cmake/trtllm.cmake
+++ b/backends/trtllm/cmake/trtllm.cmake
@@ -14,11 +14,13 @@ message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
 set(ENABLE_UCX OFF)
 if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
     set(FAST_BUILD ON)
-    set(NVTX_DISABLE OFF)
+    set(NVTX_DISABLE ON)
+    set(INDEX_RANGE_CHECK ON)
 else ()
     set(FAST_BUILD OFF)
     set(FAST_MATH ON)
-    set(NVTX_DISABLE ON)
+    set(NVTX_DISABLE OFF)
+    set(INDEX_RANGE_CHECK OFF)
 endif ()
 
 find_package(Python3 REQUIRED Interpreter)
--- a/backends/trtllm/csrc/backend.cpp
+++ b/backends/trtllm/csrc/backend.cpp
@@ -1,7 +1,6 @@
 #include <ranges>
 
 #include <nlohmann/json.hpp>
-#include <spdlog/spdlog.h>
 
 #include "backend.hpp"
 #include "hardware.hpp"
@@ -17,7 +16,8 @@ namespace huggingface::tgi::backends::trtllm {
         if (world_size > 1) {
             SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
             mode = tle::CommunicationMode::kORCHESTRATOR;
-            orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr, true);
+            orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr,
+                                                                             true);
         } else {
             SPDLOG_INFO("Detected single engine deployment, using leader mode");
         }
@@ -51,14 +51,15 @@ namespace huggingface::tgi::backends::trtllm {
     }
 
     std::expected<request_id_t, backend_error_t>
-    backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t generation_params, const sampling_params_t sampling_params) noexcept {
-        SPDLOG_DEBUG("Submitting {:d} tokens to the executor for scheduling ({}, {})", token_ids.size(), generation_params, sampling_params);
-        return executor_.enqueueRequest(tle::Request {
+    backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t g_params,
+                      const sampling_params_t s_params) noexcept {
+        SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params);
+        return executor_.enqueueRequest(tle::Request{
                 {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens
-                static_cast<tle::SizeType32>(generation_params.max_new_tokens),
+                static_cast<tle::SizeType32>(g_params.max_new_tokens),
                 true,
-                (tle::SamplingConfig) sampling_params,
-                tle::OutputConfig { /* returnLogProbs= */ true },
+                (tle::SamplingConfig) s_params,
+                tle::OutputConfig{ /* returnLogProbs= */ true},
                 std::nullopt,
                 std::nullopt,
                 std::nullopt,
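Note: submit() reports scheduling failures through its std::expected<request_id_t, backend_error_t> return value rather than by throwing. A self-contained sketch of the caller-side pattern (the alias and the error enumerator are illustrative stand-ins, not the backend's real definitions):

#include <cstdint>
#include <cstdio>
#include <expected>

using request_id_t = uint64_t;
enum class backend_error_t { EXECUTOR_SCHEDULING_FAILED }; // hypothetical stand-in

// Toy submit(): fails on demand, otherwise hands back a request id.
std::expected<request_id_t, backend_error_t> submit_stub(const bool succeed) {
    if (!succeed)
        return std::unexpected(backend_error_t::EXECUTOR_SCHEDULING_FAILED);
    return request_id_t{42};
}

int main() {
    if (const auto maybe_request_id = submit_stub(true); maybe_request_id.has_value()) {
        std::printf("scheduled as request %llu\n", static_cast<unsigned long long>(*maybe_request_id));
    } else {
        std::printf("scheduling failed with error %d\n", static_cast<int>(maybe_request_id.error()));
    }
    return 0;
}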
--- a/backends/trtllm/csrc/ffi.hpp
+++ b/backends/trtllm/csrc/ffi.hpp
@@ -28,9 +28,52 @@
 
 #include "backends/trtllm/src/lib.rs.h"
 
 
 namespace huggingface::tgi::backends::trtllm {
     std::once_flag backend_initialized_flag;
 
+    constexpr finish_reason_t as_finish_reason_t(const tle::FinishReason reason) noexcept {
+        switch (reason) {
+            case tle::FinishReason::kNOT_FINISHED:
+                return finish_reason_t::kNOT_FINISHED;
+            case tle::FinishReason::kSTOP_WORDS:
+                return finish_reason_t::kSTOP_WORDS;
+            case tle::FinishReason::kEND_ID:
+                return finish_reason_t::kEND_ID;
+            case tle::FinishReason::kLENGTH:
+                return finish_reason_t::kLENGTH;
+            default:
+                std::unreachable();
+        }
+    }
+
+    static auto as_generation_step = [](const tle::Response &r) {
+        const auto reqId = r.getRequestId();
+        if (!r.hasError()) [[likely]] {
+            const auto result = r.getResult();
+            return generation_step_t{
+                    reqId,
+                    static_cast<uint32_t>(result.outputTokenIds[0][0]),
+                    result.logProbs.value()[0][0],
+                    result.isFinal,
+                    as_finish_reason_t(result.finishReasons[0]),
+                    false,
+                    std::string()
+            };
+        } else {
+            return generation_step_t{
+                    reqId,
+                    0,
+                    0.0,
+                    true,
+                    finish_reason_t::kNOT_FINISHED,
+                    true,
+                    std::move(r.getErrorMsg())
+            };
+        }
+    };
+
+
     class tensorrt_llm_backend_t {
     private:
         backend_t inner_;
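Note: as_finish_reason_t pairs an exhaustive switch with C++23 std::unreachable(): every enumerator the executor can produce is mapped explicitly, and the default branch is declared impossible, which lets the optimizer drop any fallback code (and makes an unmapped value undefined behavior). The same pattern in a self-contained sketch with stand-in enums:

#include <cstdio>
#include <utility>

enum class source_reason { not_finished, stop_words, end_id, length }; // stand-in for tle::FinishReason
enum class target_reason { not_finished, stop_words, end_id, length }; // stand-in for finish_reason_t

// Every enumerator is mapped; default is declared impossible.
constexpr target_reason convert(const source_reason r) noexcept {
    switch (r) {
        case source_reason::not_finished: return target_reason::not_finished;
        case source_reason::stop_words:   return target_reason::stop_words;
        case source_reason::end_id:       return target_reason::end_id;
        case source_reason::length:       return target_reason::length;
        default: std::unreachable(); // UB if an unmapped value ever gets here
    }
}

int main() {
    static_assert(convert(source_reason::length) == target_reason::length);
    std::puts("finish-reason mapping checked at compile time");
    return 0;
}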
@@ -39,9 +82,7 @@ namespace huggingface::tgi::backends::trtllm {
         tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
                 : inner_(engine_folder, executor_worker_path) {}
 
-        size_t num_tokens_ready() const noexcept {
-            return inner_.num_tokens_ready();
-        }
+        size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }
 
         request_id_t submit(
                 rust::Slice<const uint32_t> tokens,
@@ -65,7 +106,7 @@ namespace huggingface::tgi::backends::trtllm {
             );
 
             // If we do have a value, let's return the request_id
-            if(maybe_request_id.has_value()) [[likely]] {
+            if (maybe_request_id.has_value()) [[likely]] {
                 return *maybe_request_id;
             } else {
                 SPDLOG_WARN("[FFI] Failed to submit request to the executor");
@@ -74,45 +115,29 @@ namespace huggingface::tgi::backends::trtllm {
         }
 
         std::unique_ptr<std::vector<generation_step_t>> pull_tokens() noexcept {
-            if(num_tokens_ready() > 0) [[likely]] {
+            if (num_tokens_ready() > 0) [[likely]] {
                 const auto responses = inner_.pull_tokens();
 
                 SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());
-                // Transform tle::Response to GenerationStep
-                auto steps = std::make_unique<std::vector<generation_step_t>>();
-                std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
-                    const auto reqId = r.getRequestId();
-                    if (!r.hasError()) [[likely]] {
-                        const auto result = r.getResult();
-                        return generation_step_t{
-                                reqId,
-                                static_cast<uint32_t>(result.outputTokenIds[0][0]),
-                                result.logProbs.value()[0][0],
-                                result.isFinal,
-                                false,
-                                std::string()
-                        };
-                    } else {
-                        return generation_step_t{
-                                reqId,
-                                0,
-                                0.0,
-                                true,
-                                true,
-                                std::move(r.getErrorMsg())
-                        };
-                    }
-                });
-                return steps;
 
+                // Transform tle::Response to generation_step_t
+#ifdef __cpp_lib_ranges_to_container
+                auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to<std::vector>();
+#else
+                auto steps = std::vector<generation_step_t>();
+                steps.reserve(responses.size());
+                std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
+#endif
+                return std::make_unique<std::vector<generation_step_t>>(steps);
             } else {
                 return std::make_unique<std::vector<generation_step_t>>();
             }
         }
 
-        void cancel(request_id_t requestId) noexcept {
-            SPDLOG_DEBUG("[FFI] cancelling request {:d}", requestId);
-            inner_.cancel(requestId);
+        void cancel(request_id_t request_id) noexcept {
+            SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id);
+            inner_.cancel(request_id);
         }
     };
 
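Note: pull_tokens() now picks its implementation at compile time through the __cpp_lib_ranges_to_container feature-test macro, preferring the C++23 ranges pipeline and falling back to std::transform on standard libraries that do not ship std::ranges::to. The same dispatch in a self-contained sketch (square() stands in for as_generation_step):

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <ranges>
#include <vector>

static int square(const int x) { return x * x; }

int main() {
    const std::vector<int> in{1, 2, 3, 4};
#ifdef __cpp_lib_ranges_to_container
    // C++23: materialize the lazy view directly into a vector.
    auto out = in | std::views::transform(square) | std::ranges::to<std::vector>();
#else
    // Pre-C++23 fallback: same result via eager std::transform.
    std::vector<int> out;
    out.reserve(in.size());
    std::transform(in.begin(), in.end(), std::back_inserter(out), square);
#endif
    for (const int v : out) std::printf("%d ", v);
    std::printf("\n");
    return 0;
}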
@ -151,11 +176,14 @@ namespace huggingface::tgi::backends::trtllm {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<tensorrt_llm_backend_t> create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
|
std::unique_ptr<tensorrt_llm_backend_t>
|
||||||
|
create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
|
||||||
std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
|
std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
|
||||||
return std::make_unique<tensorrt_llm_backend_t>(
|
return std::make_unique<tensorrt_llm_backend_t>(
|
||||||
std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()), std::filesystem::path::format::auto_format),
|
std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
|
||||||
std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()), std::filesystem::path::format::auto_format)
|
std::filesystem::path::format::auto_format),
|
||||||
|
std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
|
||||||
|
std::filesystem::path::format::auto_format)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -6,6 +6,26 @@ mod utils;
 
 #[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
 mod ffi {
+    #[cxx_name = "finish_reason_t"]
+    #[derive(Debug, Clone, Copy)]
+    pub enum FinishReason {
+        /// The request is not finished.
+        #[cxx_name = "kNOT_FINISHED"]
+        NotFinished = 0u8,
+
+        /// The request finished because the end id was generated.
+        #[cxx_name = "kEND_ID"]
+        EndTokenId = 1u8,
+
+        /// The request finished because a stop word was generated.
+        #[cxx_name = "kSTOP_WORDS"]
+        StopWords = 2u8,
+
+        /// The request finished because the maximum number of tokens was reached.
+        #[cxx_name = "kLENGTH"]
+        MaxLength = 3u8,
+    }
+
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
     #[cxx_name = "generation_step_t"]
@@ -15,6 +35,7 @@ mod ffi {
         token_id: u32,
         log_prob: f32,
         is_final: bool,
+        finish_reason: FinishReason,
         has_error: bool,
         error_msg: String,
     }
@@ -66,3 +87,17 @@ mod ffi {
         fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
     }
 }
+
+use ffi::FinishReason;
+use text_generation_router::FinishReason as InferFinishReason;
+
+impl From<FinishReason> for InferFinishReason {
+    fn from(reason: FinishReason) -> Self {
+        match reason {
+            FinishReason::StopWords => InferFinishReason::StopSequence,
+            FinishReason::MaxLength => InferFinishReason::Length,
+            FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken,
+            _ => panic!("Cannot convert {reason:?} to text_generation_router::FinishReason"),
+        }
+    }
+}
--- a/backends/trtllm/tests/test_backend.cpp
+++ b/backends/trtllm/tests/test_backend.cpp
@@ -8,13 +8,13 @@
 
 #include "backend.hpp"
 
-
-
 using namespace huggingface::tgi::backends::trtllm;
 
 TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
 {
-    const json config_j = {{"temperature", 0.6}, {"top_p", 0.95}, {"eos_token_id", {1,2,3}}};
+    const json config_j = {{"temperature", 0.6},
+                           {"top_p", 0.95},
+                           {"eos_token_id", {1, 2, 3}}};
     const auto generation_config = generation_config_t(config_j);
 
     REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
@@ -24,8 +24,9 @@ TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
     REQUIRE_FALSE(generation_config.stop_words.empty());
     REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
 
-    for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}}))
-    {
+    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+                                                                                                        {2},
+                                                                                                        {3}})) {
         // Currently we do not support multi-tokens stop words
         REQUIRE(lhs.size() == 1);
         REQUIRE(rhs.size() == 1);
@@ -35,7 +36,7 @@ TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
 
 TEST_CASE("parse generation_config.json default", "[generation_config_t]")
 {
-    const json config_j = {{"eos_token_id", {1,2,3}}};
+    const json config_j = {{"eos_token_id", {1, 2, 3}}};
     const auto generation_config = generation_config_t(config_j);
 
     REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
@@ -44,8 +45,9 @@ TEST_CASE("parse generation_config.json default", "[generation_config_t]")
     REQUIRE_FALSE(generation_config.stop_words.empty());
    REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
 
-    for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}}))
-    {
+    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+                                                                                                        {2},
+                                                                                                        {3}})) {
         // Currently we do not support multi-tokens stop words
         REQUIRE(lhs.size() == 1);
         REQUIRE(rhs.size() == 1);