mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00
feat(tgiccl): initial commit for custom tgiccl backend
This commit is contained in:
parent
7efcb5e0ed
commit
f1c6eacd75
33
csrc/CMakeLists.txt
Normal file
33
csrc/CMakeLists.txt
Normal file
@ -0,0 +1,33 @@
|
||||
# Oldest CMake release this build description supports.
cmake_minimum_required(VERSION 3.22)

# Top-level project; enables the C, C++ and CUDA toolchains.
project(text-generation-inference LANGUAGES C CXX CUDA)

# Update some policies
# CMP0135 is only known to CMake 3.24 and newer, so opt in to its NEW
# behaviour whenever the running CMake knows about it.
if (POLICY CMP0135)
    cmake_policy(SET CMP0135 NEW)
endif ()
|
||||
|
||||
|
||||
# Define some overall constants
set(CMAKE_CXX_STANDARD 20)

# Version of PyTorch to build against, overridable with -DTORCH_VERSION=<x.y.z>.
# The CACHE STRING signature is required for the trailing text to be treated as
# a docstring; with the plain set() signature it would be stored as a second
# list element of the value ("2.3.1;Version of PyTorch to build against").
set(TORCH_VERSION "2.3.1" CACHE STRING "Version of PyTorch to build against")

# Define options
# TGI_BUILD_CCL: when ON (default) the tgiccl collective library is built.
option(TGI_BUILD_CCL "Flag to enable/disable build of tgiccl collective library" ON)

# Add some modules
# FetchContent provides FetchContent_Declare() / FetchContent_MakeAvailable().
include(FetchContent)
|
||||
|
||||
# Let's find LibTorch
# cmake/torch.cmake defines ProbeForPyTorchInstall() and ConfigurePyTorch().
include(cmake/torch.cmake)

# Both helpers shell out to ${Python3_EXECUTABLE}. The interpreter is left
# optional on purpose: ProbeForPyTorchInstall() falls back to an explicitly
# configured Torch location when `import torch` cannot be run.
find_package(Python3 COMPONENTS Interpreter)
ProbeForPyTorchInstall()
ConfigurePyTorch()

find_package(Torch REQUIRED)

# Two distinct variables are appended here:
#  - TORCH_CXX_FLAGS : the variable this file already consumed right after
#                      find_package(Torch)
#  - TORCH_CXXFLAGS  : exported to this scope by ConfigurePyTorch(); it holds
#                      the ABI flags probed from the Python install and was
#                      previously computed but never used
# An unset variable expands to an empty string, so this is safe on platforms
# or compilers where ConfigurePyTorch() exports nothing.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} ${TORCH_CXXFLAGS}")
|
||||
|
||||
# Include submodules
# The option is tested by name so that if() dereferences it itself. Expanding
# ${TGI_BUILD_CCL} first hands if() a raw string, which may be empty or may be
# re-interpreted as the name of another variable.
if (TGI_BUILD_CCL)
    # NVSHMEM is fetched first: the tgiccl target links against it.
    include(cmake/nvshmem.cmake)
    add_subdirectory(tgiccl)
endif ()
|
15
csrc/cmake/nvshmem.cmake
Normal file
15
csrc/cmake/nvshmem.cmake
Normal file
@ -0,0 +1,15 @@
|
||||
# NVSHMEM_DEBUG / NVSHMEM_VERBOSE are switched off for Release builds and on
# for every other value of CMAKE_BUILD_TYPE.
# NOTE(review): CMAKE_BUILD_TYPE is empty under multi-config generators, which
# lands in the else() branch - confirm this is the intended default.
if (CMAKE_BUILD_TYPE STREQUAL "Release")
    set(NVSHMEM_DEBUG OFF)
    set(NVSHMEM_VERBOSE OFF)
else ()
    set(NVSHMEM_DEBUG ON)
    set(NVSHMEM_VERBOSE ON)
endif ()

# Download the NVSHMEM 3.0.6 source archive.
# DOWNLOAD_EXTRACT_TIMESTAMP is deliberately not passed: the option needs a
# boolean value (the bare keyword carried none) and is unknown to CMake
# releases older than 3.24, which this project still accepts. Timestamp
# handling is governed by policy CMP0135, which the top-level CMakeLists.txt
# sets to NEW on CMake >= 3.24.
# NOTE(review): consider pinning the archive with URL_HASH once the expected
# checksum is known.
FetchContent_Declare(
    nvshmem
    URL https://developer.download.nvidia.com/compute/redist/nvshmem/3.0.6/source/nvshmem_src_3.0.6-4.txz
)

FetchContent_MakeAvailable(nvshmem)
|
148
csrc/cmake/torch.cmake
Normal file
148
csrc/cmake/torch.cmake
Normal file
@ -0,0 +1,148 @@
|
||||
# ProbeForPyTorchInstall()
#   Asks the Python interpreter in ${Python3_EXECUTABLE} where PyTorch keeps
#   its CMake package files and stores the answer in the Torch_ROOT cache
#   entry, so that a later find_package(Torch) can locate it.
#
#   - Torch_ROOT already set : it is reported and left untouched.
#   - `import torch` fails   : a status message is printed and nothing is set,
#                              so the Torch location must be given explicitly.
function(ProbeForPyTorchInstall)
    # A value from a previous configure run (or from the user) wins.
    if (Torch_ROOT)
        message(STATUS "Using cached Torch root = ${Torch_ROOT}")
        return()
    endif ()

    message(STATUS "Checking for PyTorch using ${Python3_EXECUTABLE} ...")
    # The script prints torch.utils.cmake_prefix_path without a trailing newline.
    execute_process(
        COMMAND ${Python3_EXECUTABLE}
            -c "import os;import torch;print(torch.utils.cmake_prefix_path, end='')"
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
        RESULT_VARIABLE _probe_status
        OUTPUT_VARIABLE _torch_prefix_path)

    if (_probe_status EQUAL "0")
        message(STATUS "Found PyTorch installation at ${_torch_prefix_path}")
        set(Torch_ROOT "${_torch_prefix_path}" CACHE STRING
            "Torch configure directory" FORCE)
    else ()
        message(STATUS "Unable to 'import torch' with ${Python3_EXECUTABLE} (fallback to explicit config)")
    endif ()
endfunction()
|
||||
|
||||
|
||||
# ConfigurePyTorch()
#   Extensions compiled against PyTorch must be ABI-compatible with PyTorch.
#   On Linux, there are two components to this:
#     1) Dual ABI settings for libstdc++
#        See https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_dual_abi.html
#        For this, PyTorch helpfully provides a function to check which ABI it
#        was compiled against.
#     2) C++ ABI compatibility version
#        See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html (Sec 5/6)
#   The second is a bit more complicated. GCC has official compatibility
#   strings which can be specified by -fabi-version. Clang has no notion of ABI
#   versioning (https://lists.llvm.org/pipermail/cfe-dev/2015-June/043735.html).
#   Separately, pybind11 keeps an internal variable which records its ABI info
#   (PYBIND11_INTERNALS_ID in include/pybind11/detail/internals.h). Differences
#   in this variable between this project and PyTorch will cause type errors.
#   Thus, our best option is to:
#     a) Identify which ABI version PyTorch was compiled with
#     b) Tell gcc to use that version
#   or
#     c) Tell clang to pretend to use it and hope it's ABI-compatible, and
#        tell pybind to pretend we're gcc.
#
#   MacOS does not have a dual ABI problem.
#   FIXME: I don't know if MacOS needs ABI compatibility version flags.
#
#   In the future, we may want to switch away from custom building these
#   extensions and instead rely on the Torch machinery directly (definitely
#   want to do that for official builds).
#
#   Inputs : Python3_EXECUTABLE (an interpreter able to `import torch`).
#   Output : TORCH_CXXFLAGS, set in the caller's scope on Linux with a GNU or
#            Clang C++ compiler; left untouched in every other case.
#   Errors : FATAL_ERROR on Linux when either Python probe exits non-zero.
function(ConfigurePyTorch)
    message(STATUS "Checking PyTorch ABI settings...")
    # Variables are compared by name: an unquoted ${VAR} that expands to an
    # empty string would turn the condition into a syntax error.
    if (CMAKE_SYSTEM_NAME STREQUAL "Linux")
        # Check dual ABI setting first
        execute_process(
            COMMAND ${Python3_EXECUTABLE}
                -c "import torch; import sys; sys.stdout.write('1' if torch.compiled_with_cxx11_abi() else '0')"
            RESULT_VARIABLE _result
            WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
            OUTPUT_VARIABLE _use_cxx11_abi)
        if (_result)
            message(FATAL_ERROR "Failed to determine C++ Dual ABI: ${Python3_EXECUTABLE} -> ${_result}")
        endif ()
        message(STATUS "PyTorch C++ Dual ABI setting: \"${_use_cxx11_abi}\"")

        # Check ABI compatibility version
        # The script exits with status 1 unless the pybind11 build ABI string
        # starts with "_cxxabi10"; otherwise it prints the last two characters.
        execute_process(
            COMMAND ${Python3_EXECUTABLE}
                -c "import torch; import sys; abi=torch._C._PYBIND11_BUILD_ABI; abi.startswith('_cxxabi10') or sys.exit(1); sys.stdout.write(str(abi[-2:]))"
            RESULT_VARIABLE _result
            WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
            OUTPUT_VARIABLE _cxx_abi_version)
        if (_result)
            message(FATAL_ERROR "Failed to determine C++ ABI version")
        endif ()
        message(STATUS "PyTorch C++ ABI version: \"${_cxx_abi_version}\"")

        # Specialize compile flags for compiler
        if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
            set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}")
        elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
            set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
        else ()
            message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
            return()
        endif ()
        # Hand the result back to the caller.
        set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE)
    endif ()
endfunction()
|
||||
|
||||
# ConfigureLibTorch()
#   Counterpart of ConfigurePyTorch() that inspects the files under
#   ${TORCH_INSTALL_PREFIX} instead of running Python:
#     - the libstdc++ dual ABI setting is read from the
#       _GLIBCXX_USE_CXX11_ABI=<0|1> definition in TorchConfig.cmake
#     - the C++ ABI version is the last two characters of the "_cxxabiNNNN"
#       string found in libtorch_python.so
#
#   Inputs : TORCH_INSTALL_PREFIX; `bash`, `grep` and `strings` on the PATH.
#   Output : TORCH_CXXFLAGS, set in the caller's scope on Linux with a GNU or
#            Clang C++ compiler; left untouched in every other case.
#   Errors : FATAL_ERROR on Linux when either shell pipeline exits non-zero
#            (the status is that of the last grep, i.e. "nothing matched").
function(ConfigureLibTorch)
    message(STATUS "Checking LibTorch ABI settings...")
    # Variables are compared by name: an unquoted ${VAR} that expands to an
    # empty string would turn the condition into a syntax error.
    if (CMAKE_SYSTEM_NAME STREQUAL "Linux")
        message(STATUS "libtorch_python is ${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so")
        # Check dual ABI setting first
        # `grep -E` replaces the obsolescent `egrep`; the file is passed to
        # grep directly and quoted so that a prefix with spaces keeps working.
        execute_process(
            COMMAND bash "-c" "grep -E -o '_GLIBCXX_USE_CXX11_ABI=[0-1]' \"${TORCH_INSTALL_PREFIX}/share/cmake/Torch/TorchConfig.cmake\" | grep -E -o '.$'"
            RESULT_VARIABLE _result
            OUTPUT_VARIABLE _use_cxx11_abi
            OUTPUT_STRIP_TRAILING_WHITESPACE)
        if (_result)
            message(FATAL_ERROR "Failed to determine LibTorch C++ Dual ABI")
        endif ()
        message(STATUS "LibTorch C++ Dual ABI setting: \"${_use_cxx11_abi}\"")

        # Check ABI compatibility version
        execute_process(
            COMMAND bash "-c" "strings \"${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so\" | grep -E '^_cxxabi[0-9]{4}' | grep -E -o '..$'"
            RESULT_VARIABLE _result
            OUTPUT_VARIABLE _cxx_abi_version
            OUTPUT_STRIP_TRAILING_WHITESPACE)
        if (_result)
            message(FATAL_ERROR "Failed to determine LibTorch C++ ABI version")
        endif ()
        message(STATUS "LibTorch C++ ABI version: \"${_cxx_abi_version}\"")

        # Specialize compile flags for compiler
        if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
            set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}")
        elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
            set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
        else ()
            message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
            return()
        endif ()
        # Hand the result back to the caller.
        set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE)
    endif ()
endfunction()
|
||||
|
||||
# torch_mlir_python_target_compile_options(<target>)
#   Adds PRIVATE compile options to <target>, selected by the C++ compiler:
#     - Clang / AppleClang / GNU : RTTI and exceptions enabled, two noisy
#                                  pybind warnings disabled
#     - MSVC                     : RTTI and exceptions enabled
function(torch_mlir_python_target_compile_options target)
    # Enable RTTI and exceptions, then silence noisy pybind warnings.
    set(_gnu_like_options
        -frtti -fexceptions
        -Wno-unused-value
        -Wno-covered-switch-default)
    # Enable RTTI and exceptions.
    set(_msvc_options /EHsc /GR)

    # Each generator expression is quoted so that it reaches
    # target_compile_options() as one argument carrying a ;-separated list.
    target_compile_options(${target} PRIVATE
        "$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:${_gnu_like_options}>"
        "$<$<CXX_COMPILER_ID:MSVC>:${_msvc_options}>"
    )
endfunction()
|
7
csrc/tgiccl/CMakeLists.txt
Normal file
7
csrc/tgiccl/CMakeLists.txt
Normal file
@ -0,0 +1,7 @@
|
||||
project(tgiccl)

# Files that make up the library. Only a header exists so far; translation
# units should be added to the add_library() call once they are written.
set(TGICCL_HEADER_FILES tgiccl.hpp)

add_library(tgiccl SHARED ${TGICCL_HEADER_FILES})

# A target made of headers only gives CMake nothing to infer the linker
# language from, which fails at generate time; state it explicitly.
set_target_properties(tgiccl PROPERTIES LINKER_LANGUAGE CXX)

# PUBLIC keeps the dependency transitive, as the keyword-less signature did,
# while using the explicit-visibility form.
target_link_libraries(tgiccl PUBLIC nvshmem)
|
22
csrc/tgiccl/tgiccl.hpp
Normal file
22
csrc/tgiccl/tgiccl.hpp
Normal file
@ -0,0 +1,22 @@
|
||||
//
|
||||
// Created by mfuntowicz on 9/25/24.
|
||||
//
|
||||
|
||||
#ifndef TEXT_GENERATION_INFERENCE_TGICCL_H
|
||||
#define TEXT_GENERATION_INFERENCE_TGICCL_H
|
||||
|
||||
#include <torch/csrc/distributed/c10d/Backend.hpp>
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||
|
||||
/// Textual name of the tgiccl backend.
/// NOTE(review): presumably the identifier used when registering the backend
/// with torch.distributed - confirm against the registration code.
constexpr const char *CLL_BACKEND_NAME = "tgiccl";

namespace huggingface::tgi {
    /// Placeholder for the custom collective-communication backend.
    /// The type declares no members yet.
    class TgiCcl {
    };
}
|
||||
|
||||
#endif //TEXT_GENERATION_INFERENCE_TGICCL_H
|
Loading…
Reference in New Issue
Block a user