diff --git a/my_optims/nccl_test/my_custom_comm.cc b/my_optims/nccl_test/my_custom_comm.cc
index 48c65f16..783c9ebb 100644
--- a/my_optims/nccl_test/my_custom_comm.cc
+++ b/my_optims/nccl_test/my_custom_comm.cc
@@ -6,29 +6,31 @@
 #include
 #include

-struct NcclParam {
-    int rank_{0};
-    int world_size_{1};
+struct NcclParam
+{
+    int rank_{0};
+    int world_size_{1};
     ncclUniqueId nccl_uid_;
-    ncclComm_t nccl_comm_ = nullptr;
+    ncclComm_t nccl_comm_ = nullptr;
     cudaStream_t stream_;

-    NcclParam(): rank_(0), world_size_(1), nccl_comm_(nullptr){};
-    NcclParam(int rank, int world_size): rank_(rank), world_size_(world_size){};
-    NcclParam(NcclParam const& param):
-        rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
+    NcclParam() : rank_(0), world_size_(1), nccl_comm_(nullptr){};
+    NcclParam(int rank, int world_size) : rank_(rank), world_size_(world_size){};
+    NcclParam(NcclParam const &param) : rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
 };

-#define NCCLCHECK(cmd) \
-    do { \
-        ncclResult_t r = cmd; \
-        if (r != ncclSuccess) { \
-            printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
-            exit(EXIT_FAILURE); \
-        } \
+#define NCCLCHECK(cmd) \
+    do \
+    { \
+        ncclResult_t r = cmd; \
+        if (r != ncclSuccess) \
+        { \
+            printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
+            exit(EXIT_FAILURE); \
+        } \
     } while (0)

-std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_para_size)
+std::tuple<NcclParam *, NcclParam *> init_nccl(int tensor_para_size, int pipeline_para_size)
 {
     // int argc = 0;
     // char** argv = nullptr;
@@ -36,8 +38,8 @@ std::tuple init_nccl(int tensor_para_size, int pipeline_
     // // Initialize MPI
     // MPI_Init(&argc, &argv);
     int rank, world_size;
-    MPI_Comm_rank(MPI_COMM_WORLD, &rank);       // get the rank of the current process
-    MPI_Comm_size(MPI_COMM_WORLD, &world_size); // get the total number of processes
+    MPI_Comm_rank(MPI_COMM_WORLD, &rank);       // get the rank of the current process
+    MPI_Comm_size(MPI_COMM_WORLD, &world_size); // get the total number of processes
     // printf("rank:%d, world_size:%d\n", rank, world_size);

     // set the GPU
@@ -57,7 +59,7 @@ std::tuple init_nccl(int tensor_para_size, int pipeline_

     // row = a tensor parallel group, col = a pipeline parallel group.
     MPI_Comm grid_comm, tp_comm, pp_comm;
-    int dims[2] = {pipeline_para_size, tensor_para_size};
+    int dims[2] = {pipeline_para_size, tensor_para_size};
     int periods[2] = {0, 0};
     MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &grid_comm);
@@ -75,10 +77,12 @@ std::tuple init_nccl(int tensor_para_size, int pipeline_
     ncclUniqueId tp_uid;
     ncclUniqueId pp_uid;
     // The root of each group creates a nccl uid.
-    if (tp_rank == 0) {
+    if (tp_rank == 0)
+    {
         NCCLCHECK(ncclGetUniqueId(&tp_uid));
     }
-    if (pp_rank == 0) {
+    if (pp_rank == 0)
+    {
         NCCLCHECK(ncclGetUniqueId(&pp_uid));
     }
     // Broadcast nccl uid to share the same nccl uid across gpus in the same group.
@@ -89,31 +93,33 @@ std::tuple init_nccl(int tensor_para_size, int pipeline_
     NCCLCHECK(ncclCommInitRank(&tp_nccl_comm, tensor_para_size, tp_uid, tp_rank));
     NCCLCHECK(ncclCommInitRank(&pp_nccl_comm, pipeline_para_size, pp_uid, pp_rank));

-    tensor_para.world_size_ = tensor_para_size;
-    tensor_para.rank_ = tp_rank;
-    tensor_para.nccl_uid_ = tp_uid;
-    tensor_para.nccl_comm_ = tp_nccl_comm;
+    tensor_para.world_size_ = tensor_para_size;
+    tensor_para.rank_ = tp_rank;
+    tensor_para.nccl_uid_ = tp_uid;
+    tensor_para.nccl_comm_ = tp_nccl_comm;
     cudaStreamCreate(&tensor_para.stream_);
     pipeline_para.world_size_ = pipeline_para_size;
-    pipeline_para.rank_ = pp_rank;
-    pipeline_para.nccl_uid_ = pp_uid;
-    pipeline_para.nccl_comm_ = pp_nccl_comm;
+    pipeline_para.rank_ = pp_rank;
+    pipeline_para.nccl_uid_ = pp_uid;
+    pipeline_para.nccl_comm_ = pp_nccl_comm;
+
+    NcclParam *tensor_para_ptr = &tensor_para;
+    NcclParam *pipeline_para_ptr = &pipeline_para;

-    NcclParam* tensor_para_ptr = &tensor_para;
-    NcclParam* pipeline_para_ptr = &pipeline_para;
-
     return std::make_tuple(tensor_para_ptr, pipeline_para_ptr);
 }

-void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
+void finalize_nccl(NcclParam *tensor_para_ptr, NcclParam *pipeline_para_ptr)
 {
     // destroy the NCCL communicators
-    NcclParam tensor_para = *tensor_para_ptr;
+    NcclParam tensor_para = *tensor_para_ptr;
     NcclParam pipeline_para = *pipeline_para_ptr;
-    if (tensor_para.nccl_comm_ != nullptr) {
+    if (tensor_para.nccl_comm_ != nullptr)
+    {
         ncclCommDestroy(tensor_para.nccl_comm_);
     }
-    if (pipeline_para.nccl_comm_ != nullptr) {
+    if (pipeline_para.nccl_comm_ != nullptr)
+    {
         ncclCommDestroy(pipeline_para.nccl_comm_);
     }
     // MPI_Finalize();
@@ -122,43 +128,65 @@ void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
 ncclDataType_t getNcclDataType(torch::ScalarType torch_type)
 {
     ncclDataType_t nccl_type;
-    if (torch_type == torch::kFloat16) {
+    if (torch_type == torch::kFloat16)
+    {
         nccl_type = ncclHalf;
     }
-    else if (torch_type == torch::kFloat32) {
+    else if (torch_type == torch::kFloat32)
+    {
         nccl_type = ncclFloat;
     }
-    else if (torch_type == torch::kFloat64) {
+    else if (torch_type == torch::kFloat64)
+    {
         nccl_type = ncclDouble;
     }
-    else if (torch_type == torch::kInt32) {
+    else if (torch_type == torch::kInt32)
+    {
         nccl_type = ncclInt32;
     }
-    else if (torch_type == torch::kInt64) {
+    else if (torch_type == torch::kInt64)
+    {
         nccl_type = ncclInt64;
     }
-    else if (torch_type == torch::kInt8) {
+    else if (torch_type == torch::kInt8)
+    {
         nccl_type = ncclInt8;
     }
-    else {
+    else
+    {
         printf("[ERROR] NCCL only support float, half, int \n");
         exit(-1);
     }
     return nccl_type;
 }

-void custom_allreduce(torch::Tensor tensor, NcclParam* nccl_param_ptr)
+void custom_allreduce(torch::Tensor tensor, NcclParam *nccl_param_ptr)
 {
-    void* data_ptr = tensor.data_ptr();
-    size_t count = tensor.numel();
+    void *data_ptr = tensor.data_ptr();
+    size_t data_size = tensor.numel();

     torch::ScalarType torch_type = tensor.scalar_type();
-    NcclParam nccl_param = *nccl_param_ptr;
+    NcclParam nccl_param = *nccl_param_ptr;
     // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
     NCCLCHECK(ncclGroupStart());
     NCCLCHECK(ncclAllReduce(
-        (const void*)data_ptr, (void*)data_ptr, count, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
+        (const void *)data_ptr, (void *)data_ptr, data_size, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
+    NCCLCHECK(ncclGroupEnd());
+}
+void custom_allgather_into_tensor(torch::Tensor recv_tensor, torch::Tensor send_tensor, NcclParam *nccl_param_ptr)
+{
+    void *send_ptr = send_tensor.data_ptr();
+    void *recv_ptr = recv_tensor.data_ptr();
+    size_t data_size = send_tensor.numel();
+
+    torch::ScalarType torch_type = send_tensor.scalar_type();
+
+    NcclParam nccl_param = *nccl_param_ptr;
+    ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
+    NCCLCHECK(ncclGroupStart());
+    NCCLCHECK(ncclAllGather(
+        (const void *)(send_ptr), (void *)recv_ptr, data_size, nccl_data_type, nccl_param.nccl_comm_, nccl_param.stream_));
     NCCLCHECK(ncclGroupEnd());
 }

@@ -168,4 +196,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("init_nccl", &init_nccl, py::return_value_policy::reference, "");
     m.def("finalize_nccl", &finalize_nccl, "");
     m.def("custom_allreduce", &custom_allreduce, "");
+    m.def("custom_allgather_into_tensor", &custom_allgather_into_tensor, "");
 }
\ No newline at end of file
diff --git a/my_optims/nccl_test/test_extension.py b/my_optims/nccl_test/test_extension.py
index 277f9531..eafc3231 100644
--- a/my_optims/nccl_test/test_extension.py
+++ b/my_optims/nccl_test/test_extension.py
@@ -12,11 +12,13 @@
 torch.cuda.set_device(device)
 torch.cuda.set_per_process_memory_fraction(1., device)
 print(rank, world_size)
-t = torch.tensor([[1, 2, 3, 4], [3, 3, 3, 3.1]],
-                 dtype=torch.float16).to('cuda')
-
+t = torch.tensor([rank+3.1]*4, dtype=torch.float16).to('cuda').view((1,-1))
+print(t.shape)
+world_out = t.new_empty(1, 8)
+print(my_custom_comm.custom_allgather_into_tensor(world_out, t, tp_ptr))
 print(my_custom_comm.custom_allreduce(t, tp_ptr))
-print(t)
-
+print(t, world_out)
 torch.cuda.synchronize()
+
+
 my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)
\ No newline at end of file
diff --git a/my_optims/tgi_update/layers.py b/my_optims/tgi_update/layers.py
index d0667bda..eb37acad 100644
--- a/my_optims/tgi_update/layers.py
+++ b/my_optims/tgi_update/layers.py
@@ -359,7 +359,6 @@ class TensorParallelHead(SuperLayer):
     def load(config, prefix: str, weights):
         if weights.process_group.size() > 1:
             try:
-                assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
                 weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                 should_gather = True
             except AssertionError:
@@ -402,14 +401,18 @@ class TensorParallelHead(SuperLayer):

             torch.mm(input, self.linear.weight.T, out=local_out)

-            torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
+            if USE_CUSTOM_NCCL:
+                my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, self.process_group.tp_comm)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    world_out, gather_input, group=self.process_group
+                )

             if input.shape[0] == 1:
                 return world_out
             return world_out.T

+        assert USE_CUSTOM_NCCL == 0, "should not be here for llama with custom nccl"
         output = super().forward(input)
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index d0667bda..eb37acad 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -359,7 +359,6 @@ class TensorParallelHead(SuperLayer):
     def load(config, prefix: str, weights):
         if weights.process_group.size() > 1:
             try:
-                assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
                 weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                 should_gather = True
             except AssertionError:
@@ -402,14 +401,18 @@ class TensorParallelHead(SuperLayer):

             torch.mm(input, self.linear.weight.T, out=local_out)

-            torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
+            if USE_CUSTOM_NCCL:
+                my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, self.process_group.tp_comm)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    world_out, gather_input, group=self.process_group
+                )

             if input.shape[0] == 1:
                 return world_out
             return world_out.T

+        assert USE_CUSTOM_NCCL == 0, "should not be here for llama with custom nccl"
         output = super().forward(input)
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
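
The patch does not include a build script for the extension, so the exact compile and launch steps are not shown. A minimal sketch of one way to JIT-compile and run it is below; the source path comes from the diff, but the link flags, include paths, and the mpirun launch line are assumptions, not part of the patch.

```python
# build_and_run.py -- hypothetical helper, not part of this patch.
# Assumes NCCL and an MPI implementation are installed; adjust include/library
# paths (e.g. extra_include_paths / extra_ldflags) for your environment.
from torch.utils.cpp_extension import load

my_custom_comm = load(
    name="my_custom_comm",
    sources=["my_optims/nccl_test/my_custom_comm.cc"],
    extra_ldflags=["-lnccl", "-lmpi"],
    verbose=True,
)

# init_nccl keeps MPI_Init commented out, so test_extension.py presumably
# relies on MPI being initialized beforehand (e.g. by importing mpi4py) and
# is launched with one process per GPU, for example:
#   mpirun -np 2 python my_optims/nccl_test/test_extension.py
```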
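As a sanity check on the new collective: with two ranks (tensor_para_size=2), each rank sends t = [rank+3.1]*4, ncclAllGather concatenates the send buffers in rank order into world_out, and the subsequent in-place allreduce sums t across ranks. The snippet below is a hypothetical check that could be appended to test_extension.py after the collectives; it is not captured output, and the values are approximate because the tensors are float16.

```python
# Hypothetical check for test_extension.py, assuming exactly two ranks
# (world_size == 2); t and world_out come from the script above.
expected_gather = torch.tensor([[3.1] * 4 + [4.1] * 4], dtype=torch.float16, device="cuda")
expected_reduce = torch.tensor([[7.2] * 4], dtype=torch.float16, device="cuda")

torch.cuda.synchronize()
assert torch.allclose(world_out, expected_gather, atol=1e-2)  # gather ran before the reduce
assert torch.allclose(t, expected_reduce, atol=1e-2)          # allreduce sums t in place
```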
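In the layers.py changes, custom_allgather_into_tensor is used as a drop-in replacement for torch.distributed.all_gather_into_tensor: the C++ side passes send_tensor.numel() as the NCCL count, so the receive tensor must be contiguous, share the sender's dtype, and hold world_size times as many elements. How the tp_comm handle ends up on self.process_group is outside this diff. A small wrapper like the following (hypothetical helper, not part of the patch) makes that contract explicit:

```python
import torch
import my_custom_comm


def allgather_into_tensor(world_out: torch.Tensor, gather_input: torch.Tensor, tp_comm, world_size: int) -> torch.Tensor:
    """Hypothetical guard around the custom kernel, mirroring the expectations
    of torch.distributed.all_gather_into_tensor."""
    assert gather_input.is_contiguous() and world_out.is_contiguous()
    assert world_out.dtype == gather_input.dtype
    # ncclAllGather concatenates every rank's send buffer in rank order.
    assert world_out.numel() == world_size * gather_input.numel()
    my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, tp_comm)
    return world_out
```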