mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
add allgather

commit 116769a5f5 (parent 2fd1156d5a)
@@ -6,29 +6,31 @@
 #include <string>
 #include <torch/extension.h>
 
-struct NcclParam {
-    int rank_{0};
-    int world_size_{1};
+struct NcclParam
+{
+    int rank_{0};
+    int world_size_{1};
     ncclUniqueId nccl_uid_;
     ncclComm_t nccl_comm_ = nullptr;
     cudaStream_t stream_;
 
-    NcclParam(): rank_(0), world_size_(1), nccl_comm_(nullptr){};
-    NcclParam(int rank, int world_size): rank_(rank), world_size_(world_size){};
-    NcclParam(NcclParam const& param):
-        rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
+    NcclParam() : rank_(0), world_size_(1), nccl_comm_(nullptr){};
+    NcclParam(int rank, int world_size) : rank_(rank), world_size_(world_size){};
+    NcclParam(NcclParam const &param) : rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
 };
 
 #define NCCLCHECK(cmd) \
-    do { \
-        ncclResult_t r = cmd; \
-        if (r != ncclSuccess) { \
-            printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
-            exit(EXIT_FAILURE); \
-        } \
+    do \
+    { \
+        ncclResult_t r = cmd; \
+        if (r != ncclSuccess) \
+        { \
+            printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
+            exit(EXIT_FAILURE); \
+        } \
     } while (0)
 
-std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_para_size)
+std::tuple<NcclParam *, NcclParam *> init_nccl(int tensor_para_size, int pipeline_para_size)
 {
     // int argc = 0;
     // char** argv = nullptr;
@@ -36,8 +38,8 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_
     // // Initialize MPI
     // MPI_Init(&argc, &argv);
     int rank, world_size;
     MPI_Comm_rank(MPI_COMM_WORLD, &rank);       // get the rank of the current process
     MPI_Comm_size(MPI_COMM_WORLD, &world_size); // get the total number of processes
     // printf("rank:%d, world_size:%d\n", rank, world_size);
 
     // set the GPU
@@ -57,7 +59,7 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_
     // row = a tensor parallel group, col = a pipeline parallel group.
     MPI_Comm grid_comm, tp_comm, pp_comm;
 
     int dims[2] = {pipeline_para_size, tensor_para_size};
     int periods[2] = {0, 0};
     MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &grid_comm);
 
@@ -75,10 +77,12 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_
     ncclUniqueId tp_uid;
     ncclUniqueId pp_uid;
     // The root of each group creates a nccl uid.
-    if (tp_rank == 0) {
+    if (tp_rank == 0)
+    {
         NCCLCHECK(ncclGetUniqueId(&tp_uid));
     }
-    if (pp_rank == 0) {
+    if (pp_rank == 0)
+    {
         NCCLCHECK(ncclGetUniqueId(&pp_uid));
     }
     // Broadcast nccl uid to share the same nccl uid across gpus in the same group.
@@ -89,31 +93,33 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_
     NCCLCHECK(ncclCommInitRank(&tp_nccl_comm, tensor_para_size, tp_uid, tp_rank));
     NCCLCHECK(ncclCommInitRank(&pp_nccl_comm, pipeline_para_size, pp_uid, pp_rank));
 
     tensor_para.world_size_ = tensor_para_size;
     tensor_para.rank_ = tp_rank;
     tensor_para.nccl_uid_ = tp_uid;
     tensor_para.nccl_comm_ = tp_nccl_comm;
     cudaStreamCreate(&tensor_para.stream_);
     pipeline_para.world_size_ = pipeline_para_size;
     pipeline_para.rank_ = pp_rank;
     pipeline_para.nccl_uid_ = pp_uid;
     pipeline_para.nccl_comm_ = pp_nccl_comm;
 
-    NcclParam* tensor_para_ptr = &tensor_para;
-    NcclParam* pipeline_para_ptr = &pipeline_para;
+    NcclParam *tensor_para_ptr = &tensor_para;
+    NcclParam *pipeline_para_ptr = &pipeline_para;
 
     return std::make_tuple(tensor_para_ptr, pipeline_para_ptr);
 }
 
-void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
+void finalize_nccl(NcclParam *tensor_para_ptr, NcclParam *pipeline_para_ptr)
 {
     // destroy the NCCL communicators
     NcclParam tensor_para = *tensor_para_ptr;
     NcclParam pipeline_para = *pipeline_para_ptr;
-    if (tensor_para.nccl_comm_ != nullptr) {
+    if (tensor_para.nccl_comm_ != nullptr)
+    {
         ncclCommDestroy(tensor_para.nccl_comm_);
     }
-    if (pipeline_para.nccl_comm_ != nullptr) {
+    if (pipeline_para.nccl_comm_ != nullptr)
+    {
         ncclCommDestroy(pipeline_para.nccl_comm_);
     }
     // MPI_Finalize();
@@ -122,43 +128,65 @@ void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
 ncclDataType_t getNcclDataType(torch::ScalarType torch_type)
 {
     ncclDataType_t nccl_type;
-    if (torch_type == torch::kFloat16) {
+    if (torch_type == torch::kFloat16)
+    {
         nccl_type = ncclHalf;
     }
-    else if (torch_type == torch::kFloat32) {
+    else if (torch_type == torch::kFloat32)
+    {
         nccl_type = ncclFloat;
     }
-    else if (torch_type == torch::kFloat64) {
+    else if (torch_type == torch::kFloat64)
+    {
         nccl_type = ncclDouble;
     }
-    else if (torch_type == torch::kInt32) {
+    else if (torch_type == torch::kInt32)
+    {
         nccl_type = ncclInt32;
     }
-    else if (torch_type == torch::kInt64) {
+    else if (torch_type == torch::kInt64)
+    {
         nccl_type = ncclInt64;
     }
-    else if (torch_type == torch::kInt8) {
+    else if (torch_type == torch::kInt8)
+    {
         nccl_type = ncclInt8;
     }
-    else {
+    else
+    {
         printf("[ERROR] NCCL only support float, half, int \n");
         exit(-1);
     }
 
     return nccl_type;
 }
-void custom_allreduce(torch::Tensor tensor, NcclParam* nccl_param_ptr)
+void custom_allreduce(torch::Tensor tensor, NcclParam *nccl_param_ptr)
 {
-    void* data_ptr = tensor.data_ptr();
-    size_t count = tensor.numel();
+    void *data_ptr = tensor.data_ptr();
+    size_t data_size = tensor.numel();
     torch::ScalarType torch_type = tensor.scalar_type();
 
     NcclParam nccl_param = *nccl_param_ptr;
     // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
     NCCLCHECK(ncclGroupStart());
     NCCLCHECK(ncclAllReduce(
-        (const void*)data_ptr, (void*)data_ptr, count, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
+        (const void *)data_ptr, (void *)data_ptr, data_size, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
+    NCCLCHECK(ncclGroupEnd());
+}
+void custom_allgather_into_tensor(torch::Tensor recv_tensor, torch::Tensor send_tensor, NcclParam *nccl_param_ptr)
+{
+    void *send_ptr = send_tensor.data_ptr();
+    void *recv_ptr = recv_tensor.data_ptr();
+    size_t data_size = send_tensor.numel();
+
+    torch::ScalarType torch_type = send_tensor.scalar_type();
+
+    NcclParam nccl_param = *nccl_param_ptr;
+    ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
+    NCCLCHECK(ncclGroupStart());
+    NCCLCHECK(ncclAllGather(
+        (const void *)(send_ptr), (void *)recv_ptr, data_size, nccl_data_type, nccl_param.nccl_comm_, nccl_param.stream_));
     NCCLCHECK(ncclGroupEnd());
 }
 
@@ -168,4 +196,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("init_nccl", &init_nccl, py::return_value_policy::reference, "");
     m.def("finalize_nccl", &finalize_nccl, "");
     m.def("custom_allreduce", &custom_allreduce, "");
+    m.def("custom_allgather_into_tensor", &custom_allgather_into_tensor, "");
 }
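The pybind11 block above now exports init_nccl, finalize_nccl, custom_allreduce and the new custom_allgather_into_tensor. The diff does not show how the extension is built or imported as my_custom_comm, so the following is only a rough sketch of one way such a module could be compiled with torch.utils.cpp_extension; the source file name, link flags and parallel sizes are placeholder assumptions, and MPI must already be initialized in the process (for example by importing mpi4py), since MPI_Init stays commented out above.

# Sketch only, not the repository's build script; names and flags are assumptions.
from mpi4py import MPI  # importing mpi4py initializes MPI; the extension never calls MPI_Init
from torch.utils.cpp_extension import load

my_custom_comm = load(
    name="my_custom_comm",
    sources=["custom_comm.cpp"],        # hypothetical file holding the code above
    extra_ldflags=["-lnccl", "-lmpi"],  # link against NCCL and MPI
    verbose=True,
)

# Launch with one process per GPU, e.g. `mpirun -np 2 python test_comm.py` (script name illustrative).
tp_ptr, pp_ptr = my_custom_comm.init_nccl(2, 1)  # tensor_para_size=2, pipeline_para_size=1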
@@ -12,11 +12,13 @@ torch.cuda.set_device(device)
 torch.cuda.set_per_process_memory_fraction(1., device)
 print(rank, world_size)
 
-t = torch.tensor([[1, 2, 3, 4], [3, 3, 3, 3.1]],
-                 dtype=torch.float16).to('cuda')
+t = torch.tensor([rank+3.1]*4, dtype=torch.float16).to('cuda').view((1,-1))
+print(t.shape)
+world_out = t.new_empty(1, 8)
+print(my_custom_comm.custom_allgather_into_tensor(world_out, t, tp_ptr))
 print(my_custom_comm.custom_allreduce(t, tp_ptr))
-print(t)
+print(t, world_out)
 
 torch.cuda.synchronize()
 
 my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)
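For reference, here is the arithmetic the updated test above exercises, assuming it runs across two tensor-parallel ranks: rank r contributes t = [r + 3.1] * 4 as a (1, 4) fp16 tensor, custom_allgather_into_tensor fills the (1, 8) world_out buffer with the rank-ordered concatenation, and custom_allreduce then sums t in place across the group. A single-process sketch of the expected values (no NCCL involved):

# Sketch: expected buffers for a 2-rank tensor-parallel group; mirrors the math only.
import torch

world_size = 2
shards = [torch.full((1, 4), r + 3.1, dtype=torch.float16) for r in range(world_size)]

expected_world_out = torch.cat(shards, dim=1)  # [[3.1, 3.1, 3.1, 3.1, 4.1, 4.1, 4.1, 4.1]], shape (1, 8)
expected_t = sum(shards)                       # [[7.2, 7.2, 7.2, 7.2]], what t holds after the allreduce
print(expected_world_out, expected_t)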
@@ -359,7 +359,6 @@ class TensorParallelHead(SuperLayer):
     def load(config, prefix: str, weights):
         if weights.process_group.size() > 1:
             try:
-                assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
                 weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                 should_gather = True
             except AssertionError:
@@ -402,14 +401,18 @@ class TensorParallelHead(SuperLayer):
 
             torch.mm(input, self.linear.weight.T, out=local_out)
 
-            torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
+            if USE_CUSTOM_NCCL:
+                my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, self.process_group.tp_comm)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    world_out, gather_input, group=self.process_group
+                )
 
             if input.shape[0] == 1:
                 return world_out
             return world_out.T
 
+        assert USE_CUSTOM_NCCL == 0, "should not be here for llama with custom nccl"
         output = super().forward(input)
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
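In the TensorParallelHead.forward hunk above, the custom path hands self.process_group.tp_comm (presumably the NcclParam handle produced by init_nccl) to custom_allgather_into_tensor as a drop-in for torch.distributed.all_gather_into_tensor. Both calls must leave world_out holding the rank-ordered, flat concatenation of every rank's gather_input, which is exactly the semantics of ncclAllGather, so the surrounding shape logic (return world_out for a single token, world_out.T otherwise) stays untouched. A single-process sketch of that layout for the one-token path, with illustrative sizes:

# Sketch of the gathered layout (no distributed backend); world_size and out_dim are illustrative.
import torch

world_size = 2  # tensor-parallel degree
out_dim = 4     # per-rank shard of the lm_head output

# Each rank r writes a (1, out_dim) shard of the logits into local_out / gather_input.
shards = [torch.full((1, out_dim), float(r)) for r in range(world_size)]

# The gather concatenates the shards in rank order into world_out: shape (1, out_dim * world_size).
world_out = torch.cat(shards, dim=1)
assert world_out.shape == (1, out_dim * world_size)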