mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 16:32:12 +00:00
test(backend): more test coverage
This commit is contained in:
parent
62530649b8
commit
cc6bc339e5
@ -58,9 +58,15 @@ target_include_directories(tgi_trtllm_backend_impl PRIVATE
|
|||||||
# $<INSTALL_INTERFACE:csrc>
|
# $<INSTALL_INTERFACE:csrc>
|
||||||
)
|
)
|
||||||
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
|
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
|
||||||
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
|
target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)
|
||||||
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
|
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
|
||||||
|
|
||||||
|
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||||
|
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
|
||||||
|
else()
|
||||||
|
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
|
||||||
|
endif ()
|
||||||
|
|
||||||
# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
|
# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
|
||||||
install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
|
install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
|
||||||
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
|
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
|
||||||
@ -77,9 +83,15 @@ if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
|
|||||||
add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
|
add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
|
||||||
target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
|
target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
|
||||||
target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
|
target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
|
||||||
target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
|
target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)
|
||||||
target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
|
target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
|
||||||
|
|
||||||
|
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||||
|
target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
|
||||||
|
else()
|
||||||
|
target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
|
||||||
|
endif ()
|
||||||
|
|
||||||
if(CMAKE_BUILD_TYPE MATCHES "Debug")
|
if(CMAKE_BUILD_TYPE MATCHES "Debug")
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
|
||||||
|
@ -72,7 +72,7 @@ TEST_CASE("parse generation_config.json empty", "[generation_config_t]")
|
|||||||
REQUIRE(generation_config2.stop_words.empty());
|
REQUIRE(generation_config2.stop_words.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("parallel_config", "[backend_workspace_t]")
|
TEST_CASE("parallel_config single", "[backend_workspace_t]")
|
||||||
{
|
{
|
||||||
// Generate temporary folder
|
// Generate temporary folder
|
||||||
const auto tmp_p = std::filesystem::temp_directory_path();
|
const auto tmp_p = std::filesystem::temp_directory_path();
|
||||||
@ -88,13 +88,65 @@ TEST_CASE("parallel_config", "[backend_workspace_t]")
|
|||||||
o_generation_config << R"({"eos_token_id": []})"_json;
|
o_generation_config << R"({"eos_token_id": []})"_json;
|
||||||
o_generation_config.close();
|
o_generation_config.close();
|
||||||
|
|
||||||
const auto workspace = backend_workspace_t(absolute(tmp_p).generic_string(), absolute(tmp_p).generic_string());
|
const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());
|
||||||
const auto parallel = workspace.parallel_config();
|
const auto parallel = workspace.parallel_config();
|
||||||
REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kORCHESTRATOR);
|
REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kORCHESTRATOR);
|
||||||
|
REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);
|
||||||
|
|
||||||
|
std::filesystem::remove(config_p);
|
||||||
|
std::filesystem::remove(generation_config_p);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("parallel_config multi", "[backend_workspace_t]")
|
||||||
|
{
|
||||||
|
// Generate temporary folder
|
||||||
|
const auto tmp_p = std::filesystem::temp_directory_path();
|
||||||
|
const auto config_p = tmp_p / "config.json";
|
||||||
|
const auto generation_config_p = tmp_p / "generation_config.json";
|
||||||
|
|
||||||
|
// Generate content
|
||||||
|
std::ofstream o_config(config_p);
|
||||||
|
o_config << R"({"pretrained_config": {"mapping": {"world_size": 1}}})"_json;
|
||||||
|
o_config.close();
|
||||||
|
|
||||||
|
std::ofstream o_generation_config(generation_config_p);
|
||||||
|
o_generation_config << R"({"eos_token_id": []})"_json;
|
||||||
|
o_generation_config.close();
|
||||||
|
|
||||||
|
const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());
|
||||||
|
const auto parallel = workspace.parallel_config();
|
||||||
|
REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kLEADER);
|
||||||
|
REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);
|
||||||
|
|
||||||
|
std::filesystem::remove(config_p);
|
||||||
|
std::filesystem::remove(generation_config_p);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("executor_config", "[backend_workspace_t]")
|
TEST_CASE("executor_config", "[backend_workspace_t]")
|
||||||
{
|
{
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("sampling_params_t to tle::SamplingConfig", "[backend_t]")
|
||||||
|
{
|
||||||
|
const sampling_params_t params = {40, 0.95, 0.9, 1.0, 0.6, 2014};
|
||||||
|
const auto config = static_cast<tle::SamplingConfig>(params);
|
||||||
|
|
||||||
|
REQUIRE(config.getTopK().has_value());
|
||||||
|
REQUIRE(config.getTopK().value() == params.top_k);
|
||||||
|
|
||||||
|
REQUIRE(config.getSeed().has_value());
|
||||||
|
REQUIRE(config.getSeed().value() == params.seed);
|
||||||
|
|
||||||
|
REQUIRE(config.getTopP().has_value());
|
||||||
|
REQUIRE_THAT(*config.getTopP(), Catch::Matchers::WithinAbs(params.top_p, 1e-6f));
|
||||||
|
|
||||||
|
REQUIRE(config.getRepetitionPenalty().has_value());
|
||||||
|
REQUIRE_THAT(*config.getRepetitionPenalty(), Catch::Matchers::WithinAbs(params.repetition_penalty, 1e-6f));
|
||||||
|
|
||||||
|
REQUIRE(config.getFrequencyPenalty().has_value());
|
||||||
|
REQUIRE_THAT(*config.getFrequencyPenalty(), Catch::Matchers::WithinAbs(params.frequency_penalty, 1e-6f));
|
||||||
|
|
||||||
|
REQUIRE(config.getTemperature().has_value());
|
||||||
|
REQUIRE_THAT(*config.getTemperature(), Catch::Matchers::WithinAbs(params.temperature, 1e-6f));
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user