add allgather

HuaYZhao 2023-11-29 14:21:42 +08:00
parent 2fd1156d5a
commit 116769a5f5
4 changed files with 97 additions and 60 deletions

View File

@@ -6,7 +6,8 @@
#include <string>
#include <torch/extension.h>
struct NcclParam {
struct NcclParam
{
int rank_{0};
int world_size_{1};
ncclUniqueId nccl_uid_;
@@ -15,14 +16,15 @@ struct NcclParam {
NcclParam() : rank_(0), world_size_(1), nccl_comm_(nullptr){};
NcclParam(int rank, int world_size) : rank_(rank), world_size_(world_size){};
NcclParam(NcclParam const& param):
rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
NcclParam(NcclParam const &param) : rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
};
#define NCCLCHECK(cmd) \
do { \
do \
{ \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
if (r != ncclSuccess) \
{ \
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
@@ -75,10 +77,12 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_
ncclUniqueId tp_uid;
ncclUniqueId pp_uid;
// The root of each group creates a nccl uid.
if (tp_rank == 0) {
if (tp_rank == 0)
{
NCCLCHECK(ncclGetUniqueId(&tp_uid));
}
if (pp_rank == 0) {
if (pp_rank == 0)
{
NCCLCHECK(ncclGetUniqueId(&pp_uid));
}
// Broadcast nccl uid to share the same nccl uid across gpus in the same group.
@@ -110,10 +114,12 @@ void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
// Destroy the NCCL communicators
NcclParam tensor_para = *tensor_para_ptr;
NcclParam pipeline_para = *pipeline_para_ptr;
if (tensor_para.nccl_comm_ != nullptr) {
if (tensor_para.nccl_comm_ != nullptr)
{
ncclCommDestroy(tensor_para.nccl_comm_);
}
if (pipeline_para.nccl_comm_ != nullptr) {
if (pipeline_para.nccl_comm_ != nullptr)
{
ncclCommDestroy(pipeline_para.nccl_comm_);
}
// MPI_Finalize();
@@ -122,25 +128,32 @@ void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
ncclDataType_t getNcclDataType(torch::ScalarType torch_type)
{
ncclDataType_t nccl_type;
if (torch_type == torch::kFloat16) {
if (torch_type == torch::kFloat16)
{
nccl_type = ncclHalf;
}
else if (torch_type == torch::kFloat32) {
else if (torch_type == torch::kFloat32)
{
nccl_type = ncclFloat;
}
else if (torch_type == torch::kFloat64) {
else if (torch_type == torch::kFloat64)
{
nccl_type = ncclDouble;
}
else if (torch_type == torch::kInt32) {
else if (torch_type == torch::kInt32)
{
nccl_type = ncclInt32;
}
else if (torch_type == torch::kInt64) {
else if (torch_type == torch::kInt64)
{
nccl_type = ncclInt64;
}
else if (torch_type == torch::kInt8) {
else if (torch_type == torch::kInt8)
{
nccl_type = ncclInt8;
}
else {
else
{
printf("[ERROR] NCCL only support float, half, int \n");
exit(-1);
}
@@ -150,7 +163,7 @@ ncclDataType_t getNcclDataType(torch::ScalarType torch_type)
void custom_allreduce(torch::Tensor tensor, NcclParam *nccl_param_ptr)
{
void *data_ptr = tensor.data_ptr();
size_t count = tensor.numel();
size_t data_size = tensor.numel();
torch::ScalarType torch_type = tensor.scalar_type();
NcclParam nccl_param = *nccl_param_ptr;
@@ -158,7 +171,22 @@ void custom_allreduce(torch::Tensor tensor, NcclParam* nccl_param_ptr)
ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
NCCLCHECK(ncclGroupStart());
NCCLCHECK(ncclAllReduce(
(const void*)data_ptr, (void*)data_ptr, count, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
(const void *)data_ptr, (void *)data_ptr, data_size, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
NCCLCHECK(ncclGroupEnd());
}
void custom_allgather_into_tensor(torch::Tensor recv_tensor, torch::Tensor send_tensor, NcclParam *nccl_param_ptr)
{
void *send_ptr = send_tensor.data_ptr();
void *recv_ptr = recv_tensor.data_ptr();
size_t data_size = send_tensor.numel();
torch::ScalarType torch_type = send_tensor.scalar_type();
NcclParam nccl_param = *nccl_param_ptr;
ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
NCCLCHECK(ncclGroupStart());
NCCLCHECK(ncclAllGather(
(const void *)(send_ptr), (void *)recv_ptr, data_size, nccl_data_type, nccl_param.nccl_comm_, nccl_param.stream_));
NCCLCHECK(ncclGroupEnd());
}
@@ -168,4 +196,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("init_nccl", &init_nccl, py::return_value_policy::reference, "");
m.def("finalize_nccl", &finalize_nccl, "");
m.def("custom_allreduce", &custom_allreduce, "");
m.def("custom_allgather_into_tensor", &custom_allgather_into_tensor, "");
}
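For context, a minimal usage sketch of the new binding from Python. It assumes the extension builds to a module named my_custom_comm, that one process runs per GPU, and that init_nccl takes (tensor_para_size, pipeline_para_size) as in the signature above; device selection and process launch follow the test script in the next file.

import torch
import my_custom_comm  # assumed module name for the extension above

tensor_para_size = 2  # assumed tensor-parallel group size for this sketch
tp_ptr, pp_ptr = my_custom_comm.init_nccl(tensor_para_size, 1)

send = torch.randn(1, 4, dtype=torch.float16, device='cuda')
# ncclAllGather places each rank's send buffer back-to-back in rank order,
# so the receive tensor must hold tensor_para_size * send.numel() elements.
recv = send.new_empty(1, send.numel() * tensor_para_size)
my_custom_comm.custom_allgather_into_tensor(recv, send, tp_ptr)
torch.cuda.synchronize()
my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)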

View File

@@ -12,11 +12,13 @@ torch.cuda.set_device(device)
torch.cuda.set_per_process_memory_fraction(1., device)
print(rank, world_size)
t = torch.tensor([[1, 2, 3, 4], [3, 3, 3, 3.1]],
dtype=torch.float16).to('cuda')
t = torch.tensor([rank+3.1]*4, dtype=torch.float16).to('cuda').view((1,-1))
print(t.shape)
world_out = t.new_empty(1, 8)
print(my_custom_comm.custom_allgather_into_tensor(world_out, t, tp_ptr))
print(my_custom_comm.custom_allreduce(t, tp_ptr))
print(t)
print(t, world_out)
torch.cuda.synchronize()
my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)
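As a sanity check, with a tensor-parallel size of 2 the prints above should show roughly the following (up to float16 rounding; the exact launcher depends on how init_nccl bootstraps ranks, which this diff does not show):

# Each rank sends a (1, 4) row filled with rank + 3.1, so the all-gather
# fills the (1, 8) world_out identically on both ranks:
#   world_out ~ [[3.1, 3.1, 3.1, 3.1, 4.1, 4.1, 4.1, 4.1]]
# The in-place sum all-reduce that follows turns t on both ranks into
#   t ~ [[7.2, 7.2, 7.2, 7.2]]   # (0 + 3.1) + (1 + 3.1)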

View File

@@ -359,7 +359,6 @@ class TensorParallelHead(SuperLayer):
def load(config, prefix: str, weights):
if weights.process_group.size() > 1:
try:
assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
should_gather = True
except AssertionError:
@@ -402,6 +401,9 @@ class TensorParallelHead(SuperLayer):
torch.mm(input, self.linear.weight.T, out=local_out)
if USE_CUSTOM_NCCL:
my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, self.process_group.tp_comm)
else:
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
@@ -410,6 +412,7 @@ class TensorParallelHead(SuperLayer):
return world_out
return world_out.T
assert USE_CUSTOM_NCCL == 0, "should not be here for llama with custom nccl"
output = super().forward(input)
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())

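For reference, a single-process, shape-only sketch of the contract the custom branch above relies on: custom_allgather_into_tensor is expected to fill world_out with the rank-ordered concatenation of every rank's gather_input, matching torch.distributed.all_gather_into_tensor. The sizes below are made up for illustration and are not taken from the model code.

import torch

world_size = 2                                   # assumed group size
shard_rows, cols = 4, 3                          # hypothetical per-rank shape
shards = [torch.full((shard_rows, cols), float(r)) for r in range(world_size)]
# all-gather-into-tensor semantics: shards stacked along dim 0, in rank order
world_out = torch.cat(shards, dim=0)
assert world_out.shape == (world_size * shard_rows, cols)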
View File

@@ -359,7 +359,6 @@ class TensorParallelHead(SuperLayer):
def load(config, prefix: str, weights):
if weights.process_group.size() > 1:
try:
assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
should_gather = True
except AssertionError:
@@ -402,6 +401,9 @@ class TensorParallelHead(SuperLayer):
torch.mm(input, self.linear.weight.T, out=local_out)
if USE_CUSTOM_NCCL:
my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, self.process_group.tp_comm)
else:
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
@@ -410,6 +412,7 @@ class TensorParallelHead(SuperLayer):
return world_out
return world_out.T
assert USE_CUSTOM_NCCL == 0, "should not be here for llama with custom nccl"
output = super().forward(input)
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())