From f1c6eacd7500250592a3f16ebdbcdca7458e20be Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 25 Sep 2024 12:27:28 +0000 Subject: [PATCH] feat(tgiccl): initial commit for custom tgiccl backend --- csrc/CMakeLists.txt | 33 +++++++++ csrc/cmake/nvshmem.cmake | 15 ++++ csrc/cmake/torch.cmake | 148 +++++++++++++++++++++++++++++++++++++ csrc/tgiccl/CMakeLists.txt | 7 ++ csrc/tgiccl/tgiccl.hpp | 22 ++++++ 5 files changed, 225 insertions(+) create mode 100644 csrc/CMakeLists.txt create mode 100644 csrc/cmake/nvshmem.cmake create mode 100644 csrc/cmake/torch.cmake create mode 100644 csrc/tgiccl/CMakeLists.txt create mode 100644 csrc/tgiccl/tgiccl.hpp diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt new file mode 100644 index 00000000..9501f9bc --- /dev/null +++ b/csrc/CMakeLists.txt @@ -0,0 +1,33 @@ +cmake_minimum_required(VERSION 3.22) +project(text-generation-inference LANGUAGES C CXX CUDA) + +# Update some policies +if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") + cmake_policy(SET CMP0135 NEW) +endif () + + +# Define some overall constants +set(CMAKE_CXX_STANDARD 20) +set(TORCH_VERSION "2.3.1" "Version of PyTorch to build against") + +# Define options +option(TGI_BUILD_CCL "Flag to enable/disable build of tgiccl collective library" ON) + +# Add some modules +include(FetchContent) + +# Let's find LibTorch +include(cmake/torch.cmake) +find_package(Python3 COMPONENTS Interpreter) +ProbeForPyTorchInstall() +ConfigurePyTorch() + +find_package(Torch REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + +# Include submodules +if (${TGI_BUILD_CCL}) + include(cmake/nvshmem.cmake) + add_subdirectory(tgiccl) +endif () diff --git a/csrc/cmake/nvshmem.cmake b/csrc/cmake/nvshmem.cmake new file mode 100644 index 00000000..364da5c2 --- /dev/null +++ b/csrc/cmake/nvshmem.cmake @@ -0,0 +1,15 @@ +if (CMAKE_BUILD_TYPE STREQUAL "Release") + set(NVSHMEM_DEBUG OFF) + set(NVSHMEM_VERBOSE OFF) +else () + set(NVSHMEM_DEBUG ON) + set(NVSHMEM_VERBOSE ON) +endif () + +fetchcontent_declare( + nvshmem + URL https://developer.download.nvidia.com/compute/redist/nvshmem/3.0.6/source/nvshmem_src_3.0.6-4.txz + DOWNLOAD_EXTRACT_TIMESTAMP +) + +fetchcontent_makeavailable(nvshmem) \ No newline at end of file diff --git a/csrc/cmake/torch.cmake b/csrc/cmake/torch.cmake new file mode 100644 index 00000000..aa069d83 --- /dev/null +++ b/csrc/cmake/torch.cmake @@ -0,0 +1,148 @@ +# ProbeForPyTorchInstall +# Attempts to find a Torch installation and set the Torch_ROOT variable +# based on introspecting the python environment. This allows a subsequent +# call to find_package(Torch) to work. +function(ProbeForPyTorchInstall) + if (Torch_ROOT) + message(STATUS "Using cached Torch root = ${Torch_ROOT}") + else () + message(STATUS "Checking for PyTorch using ${Python3_EXECUTABLE} ...") + execute_process( + COMMAND ${Python3_EXECUTABLE} + -c "import os;import torch;print(torch.utils.cmake_prefix_path, end='')" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE PYTORCH_STATUS + OUTPUT_VARIABLE PYTORCH_PACKAGE_DIR) + if (NOT PYTORCH_STATUS EQUAL "0") + message(STATUS "Unable to 'import torch' with ${Python3_EXECUTABLE} (fallback to explicit config)") + return() + endif () + message(STATUS "Found PyTorch installation at ${PYTORCH_PACKAGE_DIR}") + + set(Torch_ROOT "${PYTORCH_PACKAGE_DIR}" CACHE STRING + "Torch configure directory" FORCE) + endif () +endfunction() + + +# ConfigurePyTorch +# Extensions compiled against PyTorch must be ABI-compatible with PyTorch. +# On Linux, there are two components to this: +# 1) Dual ABI settings for libstdc++ +# See https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_dual_abi.html +# For this, PyTorch helpfully provides a function to check which ABI it was +# compiled against. +# 2) C++ ABI compatibility version +# See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html (Sec 5/6) +# The second is a bit more complicated. GCC has official compatibility strings +# which can be specified by -fabi-version. Clang has no notion of ABI +# versioning (https://lists.llvm.org/pipermail/cfe-dev/2015-June/043735.html). +# Separately, pybind11 keeps an internal variable which records its ABI info +# (PYBIND11_INTERNALS_ID in include/pybind11/detail/internals.h). Differences +# in this variable between torch-mlir and PyTorch will cause type errors. +# Thus, our best option is to: +# a) Identify which ABI version PyTorch was compiled with +# b) Tell gcc to use that version +# or +# c) Tell clang to pretend to use it and hope it's ABI-compatible, and +# tell pybind to pretend we're gcc. +# +# MacOS does not have a dual ABI problem. +# FIXME: I don't know if MacOS needs ABI compatibility version flags. +# +# In the future, we may want to switch away from custom building these +# extensions and instead rely on the Torch machinery directly (definitely want +# to do that for official builds). +function(ConfigurePyTorch) + message(STATUS "Checking PyTorch ABI settings...") + if (${CMAKE_SYSTEM_NAME} STREQUAL "Linux") + # Check dual ABI setting first + execute_process( + COMMAND ${Python3_EXECUTABLE} + -c "import torch; import sys; sys.stdout.write('1' if torch.compiled_with_cxx11_abi() else '0')" + RESULT_VARIABLE _result + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE _use_cxx11_abi) + if (_result) + message(FATAL_ERROR "Failed to determine C++ Dual ABI: ${Python3_EXECUTABLE} -> ${_result}") + endif () + message(STATUS "PyTorch C++ Dual ABI setting: \"${_use_cxx11_abi}\"") + + # Check ABI compatibility version + execute_process( + COMMAND ${Python3_EXECUTABLE} + -c "import torch; import sys; abi=torch._C._PYBIND11_BUILD_ABI; abi.startswith('_cxxabi10') or sys.exit(1); sys.stdout.write(str(abi[-2:]))" + RESULT_VARIABLE _result + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE _cxx_abi_version) + if (_result) + message(FATAL_ERROR "Failed to determine C++ ABI version") + endif () + message(STATUS "PyTorch C++ ABI version: \"${_cxx_abi_version}\"") + + # Specialize compile flags for compiler + if (${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") + set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}") + elseif (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") + set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'") + else () + message(WARNING "Unrecognized compiler. Cannot determine ABI flags.") + return() + endif () + set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE) + endif () +endfunction() + +function(ConfigureLibTorch) + message(STATUS "Checking LibTorch ABI settings...") + if (${CMAKE_SYSTEM_NAME} STREQUAL "Linux") + message(STATUS "libtorch_python is ${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so") + # Check dual ABI setting first + execute_process( + COMMAND bash "-c" "cat ${TORCH_INSTALL_PREFIX}/share/cmake/Torch/TorchConfig.cmake | egrep -o '_GLIBCXX_USE_CXX11_ABI=[0-1]' | egrep -o '.$'" + RESULT_VARIABLE _result + OUTPUT_VARIABLE _use_cxx11_abi + OUTPUT_STRIP_TRAILING_WHITESPACE) + if (_result) + message(FATAL_ERROR "Failed to determine LibTorch C++ Dual ABI") + endif () + message(STATUS "LibTorch C++ Dual ABI setting: \"${_use_cxx11_abi}\"") + + # Check ABI compatibility version + execute_process( + COMMAND bash "-c" "strings ${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so | egrep '^_cxxabi[0-9]{4}' | egrep -o '..$'" + RESULT_VARIABLE _result + OUTPUT_VARIABLE _cxx_abi_version + OUTPUT_STRIP_TRAILING_WHITESPACE) + if (_result) + message(FATAL_ERROR "Failed to determine LibTorch C++ ABI version") + endif () + message(STATUS "LibTorch C++ ABI version: \"${_cxx_abi_version}\"") + + # Specialize compile flags for compiler + if (${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") + set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}") + elseif (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") + set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'") + else () + message(WARNING "Unrecognized compiler. Cannot determine ABI flags.") + return() + endif () + set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE) + endif () +endfunction() + +function(torch_mlir_python_target_compile_options target) + target_compile_options(${target} PRIVATE + $<$,$,$>: + # Enable RTTI and exceptions. + -frtti -fexceptions + # Noisy pybind warnings + -Wno-unused-value + -Wno-covered-switch-default + > + $<$: + # Enable RTTI and exceptions. + /EHsc /GR> + ) +endfunction() \ No newline at end of file diff --git a/csrc/tgiccl/CMakeLists.txt b/csrc/tgiccl/CMakeLists.txt new file mode 100644 index 00000000..36a6e829 --- /dev/null +++ b/csrc/tgiccl/CMakeLists.txt @@ -0,0 +1,7 @@ +project(tgiccl) + +set(TGICCL_HEADER_FILES tgiccl.hpp) +#set(TGICCL_SOURCE_FILES) + +add_library(tgiccl SHARED ${TGICCL_HEADER_FILES}) +target_link_libraries(tgiccl nvshmem) \ No newline at end of file diff --git a/csrc/tgiccl/tgiccl.hpp b/csrc/tgiccl/tgiccl.hpp new file mode 100644 index 00000000..10d1ffb4 --- /dev/null +++ b/csrc/tgiccl/tgiccl.hpp @@ -0,0 +1,22 @@ +// +// Created by mfuntowicz on 9/25/24. +// + +#ifndef TEXT_GENERATION_INFERENCE_TGICCL_H +#define TEXT_GENERATION_INFERENCE_TGICCL_H + +#include +#include + +constexpr const char *CLL_BACKEND_NAME = "tgiccl"; + +namespace huggingface::tgi { + class TgiCcl { + private: + + public: + + }; +} + +#endif //TEXT_GENERATION_INFERENCE_TGICCL_H