add allgather

HuaYZhao 2023-11-29 14:21:42 +08:00
parent 2fd1156d5a
commit 116769a5f5
4 changed files with 97 additions and 60 deletions

View File

@@ -6,29 +6,31 @@
 #include <string>
 #include <torch/extension.h>
-struct NcclParam {
+struct NcclParam
+{
     int rank_{0};
     int world_size_{1};
     ncclUniqueId nccl_uid_;
     ncclComm_t nccl_comm_ = nullptr;
     cudaStream_t stream_;
-    NcclParam(): rank_(0), world_size_(1), nccl_comm_(nullptr){};
-    NcclParam(int rank, int world_size): rank_(rank), world_size_(world_size){};
-    NcclParam(NcclParam const& param):
-        rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
+    NcclParam() : rank_(0), world_size_(1), nccl_comm_(nullptr){};
+    NcclParam(int rank, int world_size) : rank_(rank), world_size_(world_size){};
+    NcclParam(NcclParam const &param) : rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
 };
 #define NCCLCHECK(cmd) \
-    do { \
+    do \
+    { \
         ncclResult_t r = cmd; \
-        if (r != ncclSuccess) { \
+        if (r != ncclSuccess) \
+        { \
             printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
             exit(EXIT_FAILURE); \
         } \
     } while (0)
-std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_para_size)
+std::tuple<NcclParam *, NcclParam *> init_nccl(int tensor_para_size, int pipeline_para_size)
 {
     // int argc = 0;
     // char** argv = nullptr;
@@ -36,8 +38,8 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_para_size)
     // // Initialize MPI
     // MPI_Init(&argc, &argv);
     int rank, world_size;
     MPI_Comm_rank(MPI_COMM_WORLD, &rank);       // get the rank of the current process
     MPI_Comm_size(MPI_COMM_WORLD, &world_size); // get the total number of processes
     // printf("rank:%d, world_size:%d\n", rank, world_size);
     // Set the GPU
@@ -57,7 +59,7 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_para_size)
     // row = a tensor parallel group, col = a pipeline parallel group.
     MPI_Comm grid_comm, tp_comm, pp_comm;
     int dims[2] = {pipeline_para_size, tensor_para_size};
     int periods[2] = {0, 0};
     MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &grid_comm);
@@ -75,10 +77,12 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_para_size)
     ncclUniqueId tp_uid;
     ncclUniqueId pp_uid;
     // The root of each group creates a nccl uid.
-    if (tp_rank == 0) {
+    if (tp_rank == 0)
+    {
         NCCLCHECK(ncclGetUniqueId(&tp_uid));
     }
-    if (pp_rank == 0) {
+    if (pp_rank == 0)
+    {
         NCCLCHECK(ncclGetUniqueId(&pp_uid));
     }
     // Broadcast nccl uid to share the same nccl uid across gpus in the same group.
@@ -89,31 +93,33 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_para_size)
     NCCLCHECK(ncclCommInitRank(&tp_nccl_comm, tensor_para_size, tp_uid, tp_rank));
     NCCLCHECK(ncclCommInitRank(&pp_nccl_comm, pipeline_para_size, pp_uid, pp_rank));
     tensor_para.world_size_ = tensor_para_size;
     tensor_para.rank_ = tp_rank;
     tensor_para.nccl_uid_ = tp_uid;
     tensor_para.nccl_comm_ = tp_nccl_comm;
     cudaStreamCreate(&tensor_para.stream_);
     pipeline_para.world_size_ = pipeline_para_size;
     pipeline_para.rank_ = pp_rank;
     pipeline_para.nccl_uid_ = pp_uid;
     pipeline_para.nccl_comm_ = pp_nccl_comm;
-    NcclParam* tensor_para_ptr = &tensor_para;
-    NcclParam* pipeline_para_ptr = &pipeline_para;
+    NcclParam *tensor_para_ptr = &tensor_para;
+    NcclParam *pipeline_para_ptr = &pipeline_para;
     return std::make_tuple(tensor_para_ptr, pipeline_para_ptr);
 }
-void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
+void finalize_nccl(NcclParam *tensor_para_ptr, NcclParam *pipeline_para_ptr)
 {
     // Destroy the NCCL communicators
     NcclParam tensor_para = *tensor_para_ptr;
     NcclParam pipeline_para = *pipeline_para_ptr;
-    if (tensor_para.nccl_comm_ != nullptr) {
+    if (tensor_para.nccl_comm_ != nullptr)
+    {
         ncclCommDestroy(tensor_para.nccl_comm_);
     }
-    if (pipeline_para.nccl_comm_ != nullptr) {
+    if (pipeline_para.nccl_comm_ != nullptr)
+    {
         ncclCommDestroy(pipeline_para.nccl_comm_);
     }
     // MPI_Finalize();
@@ -122,43 +128,65 @@ void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
 ncclDataType_t getNcclDataType(torch::ScalarType torch_type)
 {
     ncclDataType_t nccl_type;
-    if (torch_type == torch::kFloat16) {
+    if (torch_type == torch::kFloat16)
+    {
         nccl_type = ncclHalf;
     }
-    else if (torch_type == torch::kFloat32) {
+    else if (torch_type == torch::kFloat32)
+    {
         nccl_type = ncclFloat;
     }
-    else if (torch_type == torch::kFloat64) {
+    else if (torch_type == torch::kFloat64)
+    {
         nccl_type = ncclDouble;
     }
-    else if (torch_type == torch::kInt32) {
+    else if (torch_type == torch::kInt32)
+    {
         nccl_type = ncclInt32;
     }
-    else if (torch_type == torch::kInt64) {
+    else if (torch_type == torch::kInt64)
+    {
         nccl_type = ncclInt64;
     }
-    else if (torch_type == torch::kInt8) {
+    else if (torch_type == torch::kInt8)
+    {
         nccl_type = ncclInt8;
     }
-    else {
+    else
+    {
         printf("[ERROR] NCCL only support float, half, int \n");
         exit(-1);
     }
     return nccl_type;
 }
-void custom_allreduce(torch::Tensor tensor, NcclParam* nccl_param_ptr)
+void custom_allreduce(torch::Tensor tensor, NcclParam *nccl_param_ptr)
 {
-    void* data_ptr = tensor.data_ptr();
-    size_t count = tensor.numel();
+    void *data_ptr = tensor.data_ptr();
+    size_t data_size = tensor.numel();
     torch::ScalarType torch_type = tensor.scalar_type();
     NcclParam nccl_param = *nccl_param_ptr;
     // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
     NCCLCHECK(ncclGroupStart());
     NCCLCHECK(ncclAllReduce(
-        (const void*)data_ptr, (void*)data_ptr, count, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
+        (const void *)data_ptr, (void *)data_ptr, data_size, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
     NCCLCHECK(ncclGroupEnd());
 }
+void custom_allgather_into_tensor(torch::Tensor recv_tensor, torch::Tensor send_tensor, NcclParam *nccl_param_ptr)
+{
+    void *send_ptr = send_tensor.data_ptr();
+    void *recv_ptr = recv_tensor.data_ptr();
+    size_t data_size = send_tensor.numel();
+    torch::ScalarType torch_type = send_tensor.scalar_type();
+    NcclParam nccl_param = *nccl_param_ptr;
+    ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
+    NCCLCHECK(ncclGroupStart());
+    NCCLCHECK(ncclAllGather(
+        (const void *)(send_ptr), (void *)recv_ptr, data_size, nccl_data_type, nccl_param.nccl_comm_, nccl_param.stream_));
+    NCCLCHECK(ncclGroupEnd());
+}
@@ -168,4 +196,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("init_nccl", &init_nccl, py::return_value_policy::reference, "");
     m.def("finalize_nccl", &finalize_nccl, "");
     m.def("custom_allreduce", &custom_allreduce, "");
+    m.def("custom_allgather_into_tensor", &custom_allgather_into_tensor, "");
 }
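Not part of the commit, but for orientation: a minimal build script for this extension might look like the sketch below. The module name my_custom_comm matches the imports used in the test script and in layers.py; the source file name and the assumption that the NCCL and MPI headers and libraries are on the default search paths are mine.

# setup.py (hypothetical): build the C++ source above as a torch extension.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_custom_comm",
    ext_modules=[
        CUDAExtension(
            name="my_custom_comm",
            sources=["my_custom_comm.cpp"],  # assumed file name for the diff above
            libraries=["nccl", "mpi"],       # link against NCCL and MPI
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)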

View File

@@ -12,11 +12,13 @@ torch.cuda.set_device(device)
 torch.cuda.set_per_process_memory_fraction(1., device)
 print(rank, world_size)
-t = torch.tensor([[1, 2, 3, 4], [3, 3, 3, 3.1]],
-                 dtype=torch.float16).to('cuda')
+t = torch.tensor([rank+3.1]*4, dtype=torch.float16).to('cuda').view((1,-1))
 print(t.shape)
+world_out = t.new_empty(1, 8)
+print(my_custom_comm.custom_allgather_into_tensor(world_out, t, tp_ptr))
 print(my_custom_comm.custom_allreduce(t, tp_ptr))
-print(t)
+print(t, world_out)
 torch.cuda.synchronize()
 my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)
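As a sanity check on what this script should produce (my own sketch, assuming it is launched as two MPI ranks, for example with mpirun -np 2): each rank contributes a (1, 4) fp16 tensor filled with rank + 3.1, ncclAllGather concatenates the two contributions rank by rank into the (1, 8) world_out buffer, and ncclAllReduce then sums t in place across ranks. The custom_* wrappers themselves return None, so the surrounding print() calls only show None.

import torch

# Expected values on both ranks for world_size == 2 (fp16 rounding aside):
expected_world_out = torch.tensor([[3.1] * 4 + [4.1] * 4], dtype=torch.float16)  # gathered, rank 0 first
expected_t = torch.tensor([[7.2] * 4], dtype=torch.float16)                      # 3.1 + 4.1, summed in place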

View File

@@ -359,7 +359,6 @@ class TensorParallelHead(SuperLayer):
     def load(config, prefix: str, weights):
         if weights.process_group.size() > 1:
             try:
-                assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
                 weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                 should_gather = True
             except AssertionError:
@@ -402,14 +401,18 @@ class TensorParallelHead(SuperLayer):
             torch.mm(input, self.linear.weight.T, out=local_out)
-            torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
+            if USE_CUSTOM_NCCL:
+                my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, self.process_group.tp_comm)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    world_out, gather_input, group=self.process_group
+                )
             if input.shape[0] == 1:
                 return world_out
             return world_out.T
+        assert USE_CUSTOM_NCCL == 0, "should not be here for llama with custom nccl"
         output = super().forward(input)
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
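The new branch above reads self.process_group.tp_comm, an attribute that a stock torch.distributed process group does not have, so the NcclParam handle returned by init_nccl presumably gets attached to the group during server start-up. A hedged sketch of that wiring follows; the helper name and the assumption that process_group is a Python-side wrapper that accepts new attributes are mine, not shown in this commit.

import my_custom_comm

def attach_custom_nccl(process_group, tensor_para_size: int, pipeline_para_size: int):
    # Hypothetical start-up wiring: create the tensor/pipeline-parallel NCCL
    # communicators once per process and expose the tensor-parallel handle where
    # TensorParallelHead.forward() looks for it when USE_CUSTOM_NCCL is set.
    tp_ptr, pp_ptr = my_custom_comm.init_nccl(tensor_para_size, pipeline_para_size)
    process_group.tp_comm = tp_ptr  # consumed as self.process_group.tp_comm above
    process_group.pp_comm = pp_ptr  # assumed analogue for the pipeline-parallel group
    return tp_ptr, pp_ptr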

View File

@@ -359,7 +359,6 @@ class TensorParallelHead(SuperLayer):
     def load(config, prefix: str, weights):
         if weights.process_group.size() > 1:
             try:
-                assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
                 weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                 should_gather = True
             except AssertionError:
@@ -402,14 +401,18 @@ class TensorParallelHead(SuperLayer):
             torch.mm(input, self.linear.weight.T, out=local_out)
-            torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
+            if USE_CUSTOM_NCCL:
+                my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, self.process_group.tp_comm)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    world_out, gather_input, group=self.process_group
+                )
             if input.shape[0] == 1:
                 return world_out
             return world_out.T
+        assert USE_CUSTOM_NCCL == 0, "should not be here for llama with custom nccl"
         output = super().forward(input)
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())