diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt
index 49e597d0..9c1f3436 100644
--- a/backends/trtllm/CMakeLists.txt
+++ b/backends/trtllm/CMakeLists.txt
@@ -58,9 +58,15 @@ target_include_directories(tgi_trtllm_backend_impl PRIVATE
 # $
 )
 target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
-target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
+target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)
 target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
 
+if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
+    target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
+else()
+    target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
+endif ()
+
 # This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
 install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
 install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
@@ -77,9 +83,15 @@ if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
     add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
     target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
     target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
-    target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
+    target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)
     target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
 
+    if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
+        target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
+    else()
+        target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
+    endif ()
+
     if(CMAKE_BUILD_TYPE MATCHES "Debug")
         set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
diff --git a/backends/trtllm/tests/test_backend.cpp b/backends/trtllm/tests/test_backend.cpp
index e58a7e1a..ae097405 100644
--- a/backends/trtllm/tests/test_backend.cpp
+++ b/backends/trtllm/tests/test_backend.cpp
@@ -72,7 +72,7 @@ TEST_CASE("parse generation_config.json empty", "[generation_config_t]")
     REQUIRE(generation_config2.stop_words.empty());
 }
 
-TEST_CASE("parallel_config", "[backend_workspace_t]")
+TEST_CASE("parallel_config single", "[backend_workspace_t]")
 {
     // Generate temporary folder
     const auto tmp_p = std::filesystem::temp_directory_path();
@@ -88,13 +88,65 @@ TEST_CASE("parallel_config", "[backend_workspace_t]")
     o_generation_config << R"({"eos_token_id": []})"_json;
     o_generation_config.close();
 
-    const auto workspace = backend_workspace_t(absolute(tmp_p).generic_string(), absolute(tmp_p).generic_string());
+    const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());
     const auto parallel = workspace.parallel_config();
     REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kORCHESTRATOR);
+    REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);
+    std::filesystem::remove(config_p);
+    std::filesystem::remove(generation_config_p);
+}
+
+TEST_CASE("parallel_config multi", "[backend_workspace_t]")
+{
+    // Generate temporary folder
+    const auto tmp_p = std::filesystem::temp_directory_path();
+    const auto config_p = tmp_p / "config.json";
+    const auto generation_config_p = tmp_p / "generation_config.json";
+
+    // Generate content
+    std::ofstream o_config(config_p);
+    o_config << R"({"pretrained_config": {"mapping": {"world_size": 1}}})"_json;
+    o_config.close();
+
+    std::ofstream o_generation_config(generation_config_p);
+    o_generation_config << R"({"eos_token_id": []})"_json;
+    o_generation_config.close();
+
+    const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());
+    const auto parallel = workspace.parallel_config();
+    REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kLEADER);
+    REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);
+
+    std::filesystem::remove(config_p);
+    std::filesystem::remove(generation_config_p);
 }
 
 TEST_CASE("executor_config", "[backend_workspace_t]")
 {
+}
+
+TEST_CASE("sampling_params_t to tle::SamplingConfig", "[backend_t]")
+{
+    const sampling_params_t params = {40, 0.95, 0.9, 1.0, 0.6, 2014};
+    const auto config = static_cast<tle::SamplingConfig>(params);
+
+    REQUIRE(config.getTopK().has_value());
+    REQUIRE(config.getTopK().value() == params.top_k);
+
+    REQUIRE(config.getSeed().has_value());
+    REQUIRE(config.getSeed().value() == params.seed);
+
+    REQUIRE(config.getTopP().has_value());
+    REQUIRE_THAT(*config.getTopP(), Catch::Matchers::WithinAbs(params.top_p, 1e-6f));
+
+    REQUIRE(config.getRepetitionPenalty().has_value());
+    REQUIRE_THAT(*config.getRepetitionPenalty(), Catch::Matchers::WithinAbs(params.repetition_penalty, 1e-6f));
+
+    REQUIRE(config.getFrequencyPenalty().has_value());
+    REQUIRE_THAT(*config.getFrequencyPenalty(), Catch::Matchers::WithinAbs(params.frequency_penalty, 1e-6f));
+
+    REQUIRE(config.getTemperature().has_value());
+    REQUIRE_THAT(*config.getTemperature(), Catch::Matchers::WithinAbs(params.temperature, 1e-6f));
 }
\ No newline at end of file
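
Note (not part of the patch): the new if()/else() keys off ${CMAKE_BUILD_TYPE}, which is a single-config-generator concept (Makefiles, Ninja). If multi-config generators ever need to be supported, the same rule, linking tensorrt_llm_nvrtc_wrapper everywhere except Debug, could be written as a generator expression instead; a minimal alternative sketch, not what this patch does:

    # Hypothetical alternative: one call per target; the last item only
    # expands to tensorrt_llm_nvrtc_wrapper for configurations other than Debug.
    target_link_libraries(tgi_trtllm_backend_impl PRIVATE
            tensorrt_llm
            nvinfer_plugin_tensorrt_llm
            $<$<NOT:$<CONFIG:Debug>>:tensorrt_llm_nvrtc_wrapper>)

The same expression would apply to tgi_trtllm_backend_tests and avoids repeating the tensorrt_llm / nvinfer_plugin_tensorrt_llm list in both branches.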