Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 14:52:20 +00:00
feat(backend): remove constexpr
This commit is contained in:
parent 881527a544
commit 62530649b8
@@ -77,8 +77,8 @@ if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
     add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
     target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
     target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
-    target_link_libraries(tgi_trtllm_backend_tests PRIVATE CUDA::cudart CUDA::nvml)
-    target_link_libraries(tgi_trtllm_backend_tests PRIVATE Catch2::Catch2WithMain tensorrt_llm nlohmann_json::nlohmann_json spdlog::spdlog)
+    target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
+    target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)

     if(CMAKE_BUILD_TYPE MATCHES "Debug")
         set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
@@ -4,14 +4,14 @@
 #include <memory>
 #include <thread>

 #include <nvml.h>
 #include <tensorrt_llm/common/tllmException.h>
 #include <tensorrt_llm/plugins/api/tllmPlugin.h>

 #include <spdlog/spdlog.h>
 #include <spdlog/pattern_formatter.h>
 #include <spdlog/fmt/fmt.h>

 #include <backend.hpp>
 #include <hardware.hpp>

 namespace rust::behavior {
 template<typename Try, typename Fail>
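A note on the `rust::behavior` block this hunk truncates: it is the customization point that the cxx Rust/C++ bridge looks for to decide how C++ exceptions crossing the FFI boundary become Rust `Result` errors. A minimal sketch of the canonical specialization from the cxx documentation (illustrative, not necessarily this file's exact body):

#include <exception>

namespace rust::behavior {
    // cxx wraps every bridged call that may throw in trycatch: `func` runs
    // the call, `fail` reports the error message back to Rust as Err(...).
    template <typename Try, typename Fail>
    static void trycatch(Try &&func, Fail &&fail) noexcept try {
        func();
    } catch (const std::exception &e) {
        fail(e.what());
    }
}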
@@ -111,7 +111,7 @@ namespace huggingface::tgi::backends::trtllm {
     }

     void cancel(request_id_t requestId) noexcept {
-        SPDLOG_DEBUG(FMT_STRING("[FFI] cancelling request {:d}"), requestId);
+        SPDLOG_DEBUG("[FFI] cancelling request {:d}", requestId);
         inner_.cancel(requestId);
     }
 };
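Context for the change above: `FMT_STRING` forces compile-time validation of the format string; with a recent fmt built as C++20, a plain string literal passed to spdlog is validated at compile time anyway, which makes the wrapper redundant. A standalone sketch, assuming spdlog with its bundled fmt:

#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG  // compile the DEBUG macros in
#include <spdlog/spdlog.h>
#include <spdlog/fmt/fmt.h>
#include <cstdint>

int main() {
    spdlog::set_level(spdlog::level::debug);  // emit debug records at run time
    const uint64_t request_id = 42;

    // Wrapped: the format string is checked when this line is compiled.
    SPDLOG_DEBUG(FMT_STRING("[FFI] cancelling request {:d}"), request_id);

    // Unwrapped: same output; recent fmt still checks string literals at
    // compile time under C++20, so nothing is lost by dropping the macro.
    SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id);
}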
@@ -144,7 +144,7 @@ namespace huggingface::tgi::backends::trtllm {

     const auto numGpus = huggingface::tgi::hardware::cuda::get_device_count();
     if (numGpus.has_value()) {
-        SPDLOG_INFO("[FFI] Detected {:d} Nvidia GPU(s)", numGpus.value());
+        SPDLOG_INFO("[FFI] Detected {:d} Nvidia GPU(s)", *numGpus);
     } else {
         SPDLOG_WARN("[FFI] Failed to detected Nvidia GPU(s) on the system");
         // todo: throw
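On `numGpus.value()` versus `*numGpus`: both return the contained value, but `value()` checks for emptiness and throws `std::bad_optional_access`, while `operator*` skips the check (and is undefined on an empty optional). Under the `has_value()` guard above the two are equivalent, so the cheaper dereference is fine. A self-contained sketch; `device_count` is a hypothetical stand-in for `get_device_count()`:

#include <cstdio>
#include <cstddef>
#include <optional>

// Hypothetical stub with the same shape as get_device_count().
std::optional<std::size_t> device_count() { return 2; }

int main() {
    const auto numGpus = device_count();
    if (numGpus.has_value()) {
        // Safe: guarded by has_value(), so *numGpus cannot be an empty access.
        std::printf("Detected %zu Nvidia GPU(s)\n", *numGpus);
    } else {
        std::puts("No Nvidia GPU detected");
    }
}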
@@ -16,13 +16,12 @@ namespace huggingface::tgi::hardware::cuda {
      * Get the number of GPUs on the local machine
      * @return std::nullopt if no device is available, otherwise >= 1
      */
-    std::optional<size_t> get_device_count() {
+    inline std::optional<size_t> get_device_count() {
         uint32_t numGpus = 0;
         if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
             return numGpus;
         } else {
             return std::nullopt;
         }
-        return std::nullopt;
     }

     /**
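Two fixes land in this hunk: the unreachable trailing `return std::nullopt;` after the if/else goes away, and the definition becomes `inline`. The `inline` is what matters for linking: the function is defined in a header that is now included from more than one translation unit (the FFI layer and the test binary), and a non-inline definition would be emitted once per includer and collide at link time. A minimal sketch of the rule, with a hypothetical header name:

// gpu_count.hpp (hypothetical)
#pragma once
#include <cstddef>
#include <optional>

// `inline` lets every translation unit that includes this header carry an
// identical definition; the linker folds them into one under the C++
// one-definition rule. Without it, two .cpp files including this header
// would fail to link with a duplicate-symbol error.
inline std::optional<std::size_t> device_count_stub() { return 1; }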
@@ -4,15 +4,43 @@
 #include <catch2/catch_all.hpp>
 #include <nlohmann/json.hpp>
-#include "../csrc/backend.hpp"
+#include <tensorrt_llm/executor/executor.h>
+
+#include "backend.hpp"

 using namespace huggingface::tgi::backends::trtllm;

-TEST_CASE("parse generation_config.json", "[generation_config_t]")
+TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
 {
     const json config_j = {{"temperature", 0.6}, {"top_p", 0.95}, {"eos_token_id", {1,2,3}}};
     const auto generation_config = generation_config_t(config_j);

     REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
     REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(0.95, 1e-6));
+
+    // Stop words
+    REQUIRE_FALSE(generation_config.stop_words.empty());
+    REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
+
+    for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}}))
+    {
+        // Currently we do not support multi-tokens stop words
+        REQUIRE(lhs.size() == 1);
+        REQUIRE(rhs.size() == 1);
+        REQUIRE_THAT(lhs, Catch::Matchers::UnorderedEquals(rhs));
+    }
+}
+
+TEST_CASE("parse generation_config.json default", "[generation_config_t]")
+{
+    const json config_j = {{"eos_token_id", {1,2,3}}};
+    const auto generation_config = generation_config_t(config_j);
+
+    REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+    REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
+    REQUIRE_FALSE(generation_config.stop_words.empty());
+    REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
+
@@ -25,8 +53,44 @@ TEST_CASE("parse generation_config.json", "[generation_config_t]")
     }
 }

+TEST_CASE("parse generation_config.json empty", "[generation_config_t]")
+{
+    const json config_j = {{"eos_token_id", {}}};
+    const auto generation_config = generation_config_t(config_j);
+
+    REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+    REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
+    REQUIRE(generation_config.stop_words.empty());
+
+    const json config_j2 = {};
+    const auto generation_config2 = generation_config_t(config_j2);
+
+    REQUIRE_THAT(generation_config2.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+    REQUIRE_THAT(generation_config2.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
+    REQUIRE(generation_config2.stop_words.empty());
+}
+
+TEST_CASE("parallel_config", "[backend_workspace_t]")
+{
+    // Generate temporary folder
+    const auto tmp_p = std::filesystem::temp_directory_path();
+    const auto config_p = tmp_p / "config.json";
+    const auto generation_config_p = tmp_p / "generation_config.json";
+
+    // Generate content
+    std::ofstream o_config(config_p);
+    o_config << R"({"pretrained_config": {"mapping": {"world_size": 2}}})"_json;
+    o_config.close();
+
+    std::ofstream o_generation_config(generation_config_p);
+    o_generation_config << R"({"eos_token_id": []})"_json;
+    o_generation_config.close();
+
+    const auto workspace = backend_workspace_t(absolute(tmp_p).generic_string(), absolute(tmp_p).generic_string());
+    const auto parallel = workspace.parallel_config();
+    REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kORCHESTRATOR);
+
+}
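Two idioms in these tests deserve a note: nlohmann::json's `_json_pointer` literal, which addresses a nested value by path, and C++23's `std::views::zip`, which walks two ranges in lockstep. A standalone sketch of both, mirroring the values used above (requires a C++23 toolchain for `views::zip`):

#include <cstdio>
#include <list>
#include <ranges>
#include <vector>

#include <nlohmann/json.hpp>

using json = nlohmann::json;
using namespace nlohmann::literals;

int main() {
    const json config_j = {{"eos_token_id", {1, 2, 3}}};

    // "/eos_token_id"_json_pointer resolves the array by path, as in the
    // stop-words size assertions.
    std::printf("eos tokens: %zu\n", config_j["/eos_token_id"_json_pointer].size());

    // views::zip pairs elements of the parsed stop words with the expected
    // token ids, exactly like the loop in the "all set" test.
    const std::vector<std::vector<int>> parsed{{1}, {2}, {3}};
    const std::list<std::vector<int>> expected{{1}, {2}, {3}};
    for (auto [lhs, rhs] : std::views::zip(parsed, expected))
        std::printf("pair sizes: %zu %zu\n", lhs.size(), rhs.size());
}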