add allgather

This commit is contained in:
HuaYZhao 2023-11-29 14:21:42 +08:00
parent 2fd1156d5a
commit 116769a5f5
4 changed files with 97 additions and 60 deletions

View File

@ -6,29 +6,31 @@
#include <string> #include <string>
#include <torch/extension.h> #include <torch/extension.h>
struct NcclParam { struct NcclParam
int rank_{0}; {
int world_size_{1}; int rank_{0};
int world_size_{1};
ncclUniqueId nccl_uid_; ncclUniqueId nccl_uid_;
ncclComm_t nccl_comm_ = nullptr; ncclComm_t nccl_comm_ = nullptr;
cudaStream_t stream_; cudaStream_t stream_;
NcclParam(): rank_(0), world_size_(1), nccl_comm_(nullptr){}; NcclParam() : rank_(0), world_size_(1), nccl_comm_(nullptr){};
NcclParam(int rank, int world_size): rank_(rank), world_size_(world_size){}; NcclParam(int rank, int world_size) : rank_(rank), world_size_(world_size){};
NcclParam(NcclParam const& param): NcclParam(NcclParam const &param) : rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
}; };
#define NCCLCHECK(cmd) \ #define NCCLCHECK(cmd) \
do { \ do \
ncclResult_t r = cmd; \ { \
if (r != ncclSuccess) { \ ncclResult_t r = cmd; \
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \ if (r != ncclSuccess) \
exit(EXIT_FAILURE); \ { \
} \ printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0) } while (0)
std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_para_size) std::tuple<NcclParam *, NcclParam *> init_nccl(int tensor_para_size, int pipeline_para_size)
{ {
// int argc = 0; // int argc = 0;
// char** argv = nullptr; // char** argv = nullptr;
@ -36,8 +38,8 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_
// // 初始化 MPI // // 初始化 MPI
// MPI_Init(&argc, &argv); // MPI_Init(&argc, &argv);
int rank, world_size; int rank, world_size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank); // 获取当前进程的 rank MPI_Comm_rank(MPI_COMM_WORLD, &rank); // 获取当前进程的 rank
MPI_Comm_size(MPI_COMM_WORLD, &world_size); // 获取进程总数 MPI_Comm_size(MPI_COMM_WORLD, &world_size); // 获取进程总数
// printf("rank:%d, world_size:%d\n", rank, world_size); // printf("rank:%d, world_size:%d\n", rank, world_size);
// 设定gpu // 设定gpu
@ -57,7 +59,7 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_
// row = a tensor parallel group, col = a pipeline parallel group. // row = a tensor parallel group, col = a pipeline parallel group.
MPI_Comm grid_comm, tp_comm, pp_comm; MPI_Comm grid_comm, tp_comm, pp_comm;
int dims[2] = {pipeline_para_size, tensor_para_size}; int dims[2] = {pipeline_para_size, tensor_para_size};
int periods[2] = {0, 0}; int periods[2] = {0, 0};
MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &grid_comm); MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &grid_comm);
@ -75,10 +77,12 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_
ncclUniqueId tp_uid; ncclUniqueId tp_uid;
ncclUniqueId pp_uid; ncclUniqueId pp_uid;
// The root of each group creates a nccl uid. // The root of each group creates a nccl uid.
if (tp_rank == 0) { if (tp_rank == 0)
{
NCCLCHECK(ncclGetUniqueId(&tp_uid)); NCCLCHECK(ncclGetUniqueId(&tp_uid));
} }
if (pp_rank == 0) { if (pp_rank == 0)
{
NCCLCHECK(ncclGetUniqueId(&pp_uid)); NCCLCHECK(ncclGetUniqueId(&pp_uid));
} }
// Broadcast nccl uid to share the same nccl uid across gpus in the same group. // Broadcast nccl uid to share the same nccl uid across gpus in the same group.
@ -89,31 +93,33 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_
NCCLCHECK(ncclCommInitRank(&tp_nccl_comm, tensor_para_size, tp_uid, tp_rank)); NCCLCHECK(ncclCommInitRank(&tp_nccl_comm, tensor_para_size, tp_uid, tp_rank));
NCCLCHECK(ncclCommInitRank(&pp_nccl_comm, pipeline_para_size, pp_uid, pp_rank)); NCCLCHECK(ncclCommInitRank(&pp_nccl_comm, pipeline_para_size, pp_uid, pp_rank));
tensor_para.world_size_ = tensor_para_size; tensor_para.world_size_ = tensor_para_size;
tensor_para.rank_ = tp_rank; tensor_para.rank_ = tp_rank;
tensor_para.nccl_uid_ = tp_uid; tensor_para.nccl_uid_ = tp_uid;
tensor_para.nccl_comm_ = tp_nccl_comm; tensor_para.nccl_comm_ = tp_nccl_comm;
cudaStreamCreate(&tensor_para.stream_); cudaStreamCreate(&tensor_para.stream_);
pipeline_para.world_size_ = pipeline_para_size; pipeline_para.world_size_ = pipeline_para_size;
pipeline_para.rank_ = pp_rank; pipeline_para.rank_ = pp_rank;
pipeline_para.nccl_uid_ = pp_uid; pipeline_para.nccl_uid_ = pp_uid;
pipeline_para.nccl_comm_ = pp_nccl_comm; pipeline_para.nccl_comm_ = pp_nccl_comm;
NcclParam *tensor_para_ptr = &tensor_para;
NcclParam *pipeline_para_ptr = &pipeline_para;
NcclParam* tensor_para_ptr = &tensor_para;
NcclParam* pipeline_para_ptr = &pipeline_para;
return std::make_tuple(tensor_para_ptr, pipeline_para_ptr); return std::make_tuple(tensor_para_ptr, pipeline_para_ptr);
} }
void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr) void finalize_nccl(NcclParam *tensor_para_ptr, NcclParam *pipeline_para_ptr)
{ {
// 销毁nccl // 销毁nccl
NcclParam tensor_para = *tensor_para_ptr; NcclParam tensor_para = *tensor_para_ptr;
NcclParam pipeline_para = *pipeline_para_ptr; NcclParam pipeline_para = *pipeline_para_ptr;
if (tensor_para.nccl_comm_ != nullptr) { if (tensor_para.nccl_comm_ != nullptr)
{
ncclCommDestroy(tensor_para.nccl_comm_); ncclCommDestroy(tensor_para.nccl_comm_);
} }
if (pipeline_para.nccl_comm_ != nullptr) { if (pipeline_para.nccl_comm_ != nullptr)
{
ncclCommDestroy(pipeline_para.nccl_comm_); ncclCommDestroy(pipeline_para.nccl_comm_);
} }
// MPI_Finalize(); // MPI_Finalize();
@ -122,43 +128,65 @@ void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
ncclDataType_t getNcclDataType(torch::ScalarType torch_type) ncclDataType_t getNcclDataType(torch::ScalarType torch_type)
{ {
ncclDataType_t nccl_type; ncclDataType_t nccl_type;
if (torch_type == torch::kFloat16) { if (torch_type == torch::kFloat16)
{
nccl_type = ncclHalf; nccl_type = ncclHalf;
} }
else if (torch_type == torch::kFloat32) { else if (torch_type == torch::kFloat32)
{
nccl_type = ncclFloat; nccl_type = ncclFloat;
} }
else if (torch_type == torch::kFloat64) { else if (torch_type == torch::kFloat64)
{
nccl_type = ncclDouble; nccl_type = ncclDouble;
} }
else if (torch_type == torch::kInt32) { else if (torch_type == torch::kInt32)
{
nccl_type = ncclInt32; nccl_type = ncclInt32;
} }
else if (torch_type == torch::kInt64) { else if (torch_type == torch::kInt64)
{
nccl_type = ncclInt64; nccl_type = ncclInt64;
} }
else if (torch_type == torch::kInt8) { else if (torch_type == torch::kInt8)
{
nccl_type = ncclInt8; nccl_type = ncclInt8;
} }
else { else
{
printf("[ERROR] NCCL only support float, half, int \n"); printf("[ERROR] NCCL only support float, half, int \n");
exit(-1); exit(-1);
} }
return nccl_type; return nccl_type;
} }
void custom_allreduce(torch::Tensor tensor, NcclParam* nccl_param_ptr) void custom_allreduce(torch::Tensor tensor, NcclParam *nccl_param_ptr)
{ {
void* data_ptr = tensor.data_ptr(); void *data_ptr = tensor.data_ptr();
size_t count = tensor.numel(); size_t data_size = tensor.numel();
torch::ScalarType torch_type = tensor.scalar_type(); torch::ScalarType torch_type = tensor.scalar_type();
NcclParam nccl_param = *nccl_param_ptr; NcclParam nccl_param = *nccl_param_ptr;
// cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
ncclDataType_t nccl_data_type = getNcclDataType(torch_type); ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
NCCLCHECK(ncclGroupStart()); NCCLCHECK(ncclGroupStart());
NCCLCHECK(ncclAllReduce( NCCLCHECK(ncclAllReduce(
(const void*)data_ptr, (void*)data_ptr, count, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_)); (const void *)data_ptr, (void *)data_ptr, data_size, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
NCCLCHECK(ncclGroupEnd());
}
void custom_allgather_into_tensor(torch::Tensor recv_tensor, torch::Tensor send_tensor, NcclParam *nccl_param_ptr)
{
void *send_ptr = send_tensor.data_ptr();
void *recv_ptr = recv_tensor.data_ptr();
size_t data_size = send_tensor.numel();
torch::ScalarType torch_type = send_tensor.scalar_type();
NcclParam nccl_param = *nccl_param_ptr;
ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
NCCLCHECK(ncclGroupStart());
NCCLCHECK(ncclAllGather(
(const void *)(send_ptr), (void *)recv_ptr, data_size, nccl_data_type, nccl_param.nccl_comm_, nccl_param.stream_));
NCCLCHECK(ncclGroupEnd()); NCCLCHECK(ncclGroupEnd());
} }
@ -168,4 +196,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("init_nccl", &init_nccl, py::return_value_policy::reference, ""); m.def("init_nccl", &init_nccl, py::return_value_policy::reference, "");
m.def("finalize_nccl", &finalize_nccl, ""); m.def("finalize_nccl", &finalize_nccl, "");
m.def("custom_allreduce", &custom_allreduce, ""); m.def("custom_allreduce", &custom_allreduce, "");
m.def("custom_allgather_into_tensor", &custom_allgather_into_tensor, "");
} }

View File

@ -12,11 +12,13 @@ torch.cuda.set_device(device)
torch.cuda.set_per_process_memory_fraction(1., device) torch.cuda.set_per_process_memory_fraction(1., device)
print(rank, world_size) print(rank, world_size)
t = torch.tensor([[1, 2, 3, 4], [3, 3, 3, 3.1]], t = torch.tensor([rank+3.1]*4, dtype=torch.float16).to('cuda').view((1,-1))
dtype=torch.float16).to('cuda') print(t.shape)
world_out = t.new_empty(1, 8)
print(my_custom_comm.custom_allgather_into_tensor(world_out, t, tp_ptr))
print(my_custom_comm.custom_allreduce(t, tp_ptr)) print(my_custom_comm.custom_allreduce(t, tp_ptr))
print(t) print(t, world_out)
torch.cuda.synchronize() torch.cuda.synchronize()
my_custom_comm.finalize_nccl(tp_ptr, pp_ptr) my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)

View File

@ -359,7 +359,6 @@ class TensorParallelHead(SuperLayer):
def load(config, prefix: str, weights): def load(config, prefix: str, weights):
if weights.process_group.size() > 1: if weights.process_group.size() > 1:
try: try:
assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
weight = weights.get_sharded(f"{prefix}.weight", dim=0) weight = weights.get_sharded(f"{prefix}.weight", dim=0)
should_gather = True should_gather = True
except AssertionError: except AssertionError:
@ -402,14 +401,18 @@ class TensorParallelHead(SuperLayer):
torch.mm(input, self.linear.weight.T, out=local_out) torch.mm(input, self.linear.weight.T, out=local_out)
torch.distributed.all_gather_into_tensor( if USE_CUSTOM_NCCL:
world_out, gather_input, group=self.process_group my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, self.process_group.tp_comm)
) else:
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
if input.shape[0] == 1: if input.shape[0] == 1:
return world_out return world_out
return world_out.T return world_out.T
assert USE_CUSTOM_NCCL == 0, "should not be here for llama with custom nccl"
output = super().forward(input) output = super().forward(input)
world_output = [ world_output = [
torch.empty_like(output) for _ in range(self.process_group.size()) torch.empty_like(output) for _ in range(self.process_group.size())

View File

@ -359,7 +359,6 @@ class TensorParallelHead(SuperLayer):
def load(config, prefix: str, weights): def load(config, prefix: str, weights):
if weights.process_group.size() > 1: if weights.process_group.size() > 1:
try: try:
assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
weight = weights.get_sharded(f"{prefix}.weight", dim=0) weight = weights.get_sharded(f"{prefix}.weight", dim=0)
should_gather = True should_gather = True
except AssertionError: except AssertionError:
@ -402,14 +401,18 @@ class TensorParallelHead(SuperLayer):
torch.mm(input, self.linear.weight.T, out=local_out) torch.mm(input, self.linear.weight.T, out=local_out)
torch.distributed.all_gather_into_tensor( if USE_CUSTOM_NCCL:
world_out, gather_input, group=self.process_group my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, self.process_group.tp_comm)
) else:
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
if input.shape[0] == 1: if input.shape[0] == 1:
return world_out return world_out
return world_out.T return world_out.T
assert USE_CUSTOM_NCCL == 0, "should not be here for llama with custom nccl"
output = super().forward(input) output = super().forward(input)
world_output = [ world_output = [
torch.empty_like(output) for _ in range(self.process_group.size()) torch.empty_like(output) for _ in range(self.process_group.size())