From 31a6065fac73b00ad1b2fceb770fb51e82eb3e65 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Thu, 26 Sep 2024 23:31:07 +0200 Subject: [PATCH] Add some utility functions in tgiccl for now --- csrc/CMakeLists.txt | 2 +- csrc/cmake/nvshmem.cmake | 15 ----------- csrc/cmake/spdlog.cmake | 6 +++++ csrc/tgiccl/CMakeLists.txt | 15 +++++++---- csrc/tgiccl/TgiCclBackend.cpp | 11 ++++++++ csrc/tgiccl/TgiCclBackend.hpp | 29 ++++++++++++++++++++ csrc/tgiccl/test_tgiccl.cpp | 11 ++++++++ csrc/tgiccl/tgiccl.hpp | 50 +++++++++++++++++++++++++++++------ 8 files changed, 110 insertions(+), 29 deletions(-) delete mode 100644 csrc/cmake/nvshmem.cmake create mode 100644 csrc/cmake/spdlog.cmake create mode 100644 csrc/tgiccl/TgiCclBackend.cpp create mode 100644 csrc/tgiccl/TgiCclBackend.hpp create mode 100644 csrc/tgiccl/test_tgiccl.cpp diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 9501f9bc3..409fc2977 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -16,6 +16,7 @@ option(TGI_BUILD_CCL "Flag to enable/disable build of tgiccl collective library" # Add some modules include(FetchContent) +include(cmake/spdlog.cmake) # Let's find LibTorch include(cmake/torch.cmake) @@ -28,6 +29,5 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") # Include submodules if (${TGI_BUILD_CCL}) - include(cmake/nvshmem.cmake) add_subdirectory(tgiccl) endif () diff --git a/csrc/cmake/nvshmem.cmake b/csrc/cmake/nvshmem.cmake deleted file mode 100644 index 364da5c20..000000000 --- a/csrc/cmake/nvshmem.cmake +++ /dev/null @@ -1,15 +0,0 @@ -if (CMAKE_BUILD_TYPE STREQUAL "Release") - set(NVSHMEM_DEBUG OFF) - set(NVSHMEM_VERBOSE OFF) -else () - set(NVSHMEM_DEBUG ON) - set(NVSHMEM_VERBOSE ON) -endif () - -fetchcontent_declare( - nvshmem - URL https://developer.download.nvidia.com/compute/redist/nvshmem/3.0.6/source/nvshmem_src_3.0.6-4.txz - DOWNLOAD_EXTRACT_TIMESTAMP -) - -fetchcontent_makeavailable(nvshmem) \ No newline at end of file diff --git a/csrc/cmake/spdlog.cmake b/csrc/cmake/spdlog.cmake new file mode 100644 index 000000000..d4e0c491c --- /dev/null +++ b/csrc/cmake/spdlog.cmake @@ -0,0 +1,6 @@ +fetchcontent_declare( + spdlog + URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz +) + +fetchcontent_makeavailable(spdlog) \ No newline at end of file diff --git a/csrc/tgiccl/CMakeLists.txt b/csrc/tgiccl/CMakeLists.txt index 36a6e8290..cda79d490 100644 --- a/csrc/tgiccl/CMakeLists.txt +++ b/csrc/tgiccl/CMakeLists.txt @@ -1,7 +1,12 @@ -project(tgiccl) +project(tgiccl LANGUAGES C CXX CUDA) -set(TGICCL_HEADER_FILES tgiccl.hpp) -#set(TGICCL_SOURCE_FILES) +set(TGICCL_HEADERS tgiccl.hpp TgiCclBackend.hpp) +set(TGICCL_SOURCES TgiCclBackend.cpp) -add_library(tgiccl SHARED ${TGICCL_HEADER_FILES}) -target_link_libraries(tgiccl nvshmem) \ No newline at end of file +find_package(CUDAToolkit REQUIRED) + +add_library(tgiccl SHARED ${TGICCL_HEADERS} ${TGICCL_SOURCES}) +target_link_libraries(tgiccl PUBLIC spdlog::spdlog CUDA::nvml ${TORCH_LIBRARIES}) + +add_executable(test_tgiccl test_tgiccl.cpp) +target_link_libraries(test_tgiccl tgiccl spdlog::spdlog) \ No newline at end of file diff --git a/csrc/tgiccl/TgiCclBackend.cpp b/csrc/tgiccl/TgiCclBackend.cpp new file mode 100644 index 000000000..1c4250bc0 --- /dev/null +++ b/csrc/tgiccl/TgiCclBackend.cpp @@ -0,0 +1,11 @@ +// +// Created by Morgan Funtowicz on 26/09/24. +// + +#include "TgiCclBackend.hpp" + + +void huggingface::tgi::tgiccl::InitTgiCcl() +{ + +} \ No newline at end of file diff --git a/csrc/tgiccl/TgiCclBackend.hpp b/csrc/tgiccl/TgiCclBackend.hpp new file mode 100644 index 000000000..cbf6e0e11 --- /dev/null +++ b/csrc/tgiccl/TgiCclBackend.hpp @@ -0,0 +1,29 @@ +// +// Created by Morgan Funtowicz on 26/09/24. +// + +#ifndef TGICCLPROCESSGROUP_H +#define TGICCLPROCESSGROUP_H + +#include +#include + + +namespace huggingface::tgi::tgiccl +{ + void InitTgiCcl(); + + class TgiCclBackend final : c10d::Backend { + public: + TgiCclBackend(const int rank, const int size): Backend(rank, size) + { + SPDLOG_INFO(FMT_STRING("Creating TgiCclBackend on rank {:d} over {:d}"), rank, size); + } + + c10::intrusive_ptr allreduce(std::vector&, const c10d::AllreduceOptions&) override; + }; +} + + + +#endif //TGICCLPROCESSGROUP_H diff --git a/csrc/tgiccl/test_tgiccl.cpp b/csrc/tgiccl/test_tgiccl.cpp new file mode 100644 index 000000000..68eee85ec --- /dev/null +++ b/csrc/tgiccl/test_tgiccl.cpp @@ -0,0 +1,11 @@ +// +// Created by morgan on 26/09/24. +// + +#include "tgiccl.hpp" + +int main() { + auto a = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 1); + auto b = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 2); + auto c = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 3); +} \ No newline at end of file diff --git a/csrc/tgiccl/tgiccl.hpp b/csrc/tgiccl/tgiccl.hpp index 10d1ffb40..1bde03b0e 100644 --- a/csrc/tgiccl/tgiccl.hpp +++ b/csrc/tgiccl/tgiccl.hpp @@ -5,18 +5,52 @@ #ifndef TEXT_GENERATION_INFERENCE_TGICCL_H #define TEXT_GENERATION_INFERENCE_TGICCL_H -#include -#include +#include -constexpr const char *CLL_BACKEND_NAME = "tgiccl"; +#include -namespace huggingface::tgi { - class TgiCcl { - private: +#include "TgiCclBackend.hpp" - public: +constexpr auto CLL_BACKEND_NAME = "tgiccl"; + +namespace huggingface::tgi::tgiccl { + + static std::once_flag NVML_INIT_FLAG; +#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2); + + inline std::optional GetDeviceByIndex(const size_t index) + { + ENSURE_NVML_INIT(); + + nvmlDevice_t device; + if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS) + return std::optional{ device }; + + return std::nullopt; + } + + inline bool IsNvLinkAvailable(const int from, const int to) + { + ENSURE_NVML_INIT(); + + // Get devices + const auto devFrom = GetDeviceByIndex(from); + const auto devTo = GetDeviceByIndex(to); + + if(!devFrom.has_value()) + SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from); + + if(!devTo.has_value()) + SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to); + + // Query link between both + nvmlGpuP2PStatus_t status; + if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS) + SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to); + + return status == NVML_P2P_STATUS_OK; + } - }; } #endif //TEXT_GENERATION_INFERENCE_TGICCL_H