mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
add allgather
This commit is contained in:
parent
2fd1156d5a
commit
116769a5f5
@ -6,7 +6,8 @@
|
||||
#include <string>
|
||||
#include <torch/extension.h>
|
||||
|
||||
struct NcclParam {
|
||||
struct NcclParam
|
||||
{
|
||||
int rank_{0};
|
||||
int world_size_{1};
|
||||
ncclUniqueId nccl_uid_;
|
||||
@ -15,14 +16,15 @@ struct NcclParam {
|
||||
|
||||
NcclParam() : rank_(0), world_size_(1), nccl_comm_(nullptr){};
|
||||
NcclParam(int rank, int world_size) : rank_(rank), world_size_(world_size){};
|
||||
NcclParam(NcclParam const& param):
|
||||
rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
|
||||
NcclParam(NcclParam const ¶m) : rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
|
||||
};
|
||||
|
||||
#define NCCLCHECK(cmd) \
|
||||
do { \
|
||||
do \
|
||||
{ \
|
||||
ncclResult_t r = cmd; \
|
||||
if (r != ncclSuccess) { \
|
||||
if (r != ncclSuccess) \
|
||||
{ \
|
||||
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
@ -75,10 +77,12 @@ std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_
|
||||
ncclUniqueId tp_uid;
|
||||
ncclUniqueId pp_uid;
|
||||
// The root of each group creates a nccl uid.
|
||||
if (tp_rank == 0) {
|
||||
if (tp_rank == 0)
|
||||
{
|
||||
NCCLCHECK(ncclGetUniqueId(&tp_uid));
|
||||
}
|
||||
if (pp_rank == 0) {
|
||||
if (pp_rank == 0)
|
||||
{
|
||||
NCCLCHECK(ncclGetUniqueId(&pp_uid));
|
||||
}
|
||||
// Broadcast nccl uid to share the same nccl uid across gpus in the same group.
|
||||
@ -110,10 +114,12 @@ void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
|
||||
// 销毁nccl
|
||||
NcclParam tensor_para = *tensor_para_ptr;
|
||||
NcclParam pipeline_para = *pipeline_para_ptr;
|
||||
if (tensor_para.nccl_comm_ != nullptr) {
|
||||
if (tensor_para.nccl_comm_ != nullptr)
|
||||
{
|
||||
ncclCommDestroy(tensor_para.nccl_comm_);
|
||||
}
|
||||
if (pipeline_para.nccl_comm_ != nullptr) {
|
||||
if (pipeline_para.nccl_comm_ != nullptr)
|
||||
{
|
||||
ncclCommDestroy(pipeline_para.nccl_comm_);
|
||||
}
|
||||
// MPI_Finalize();
|
||||
@ -122,25 +128,32 @@ void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
|
||||
ncclDataType_t getNcclDataType(torch::ScalarType torch_type)
|
||||
{
|
||||
ncclDataType_t nccl_type;
|
||||
if (torch_type == torch::kFloat16) {
|
||||
if (torch_type == torch::kFloat16)
|
||||
{
|
||||
nccl_type = ncclHalf;
|
||||
}
|
||||
else if (torch_type == torch::kFloat32) {
|
||||
else if (torch_type == torch::kFloat32)
|
||||
{
|
||||
nccl_type = ncclFloat;
|
||||
}
|
||||
else if (torch_type == torch::kFloat64) {
|
||||
else if (torch_type == torch::kFloat64)
|
||||
{
|
||||
nccl_type = ncclDouble;
|
||||
}
|
||||
else if (torch_type == torch::kInt32) {
|
||||
else if (torch_type == torch::kInt32)
|
||||
{
|
||||
nccl_type = ncclInt32;
|
||||
}
|
||||
else if (torch_type == torch::kInt64) {
|
||||
else if (torch_type == torch::kInt64)
|
||||
{
|
||||
nccl_type = ncclInt64;
|
||||
}
|
||||
else if (torch_type == torch::kInt8) {
|
||||
else if (torch_type == torch::kInt8)
|
||||
{
|
||||
nccl_type = ncclInt8;
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
printf("[ERROR] NCCL only support float, half, int \n");
|
||||
exit(-1);
|
||||
}
|
||||
@ -150,7 +163,7 @@ ncclDataType_t getNcclDataType(torch::ScalarType torch_type)
|
||||
void custom_allreduce(torch::Tensor tensor, NcclParam *nccl_param_ptr)
|
||||
{
|
||||
void *data_ptr = tensor.data_ptr();
|
||||
size_t count = tensor.numel();
|
||||
size_t data_size = tensor.numel();
|
||||
torch::ScalarType torch_type = tensor.scalar_type();
|
||||
|
||||
NcclParam nccl_param = *nccl_param_ptr;
|
||||
@ -158,7 +171,22 @@ void custom_allreduce(torch::Tensor tensor, NcclParam* nccl_param_ptr)
|
||||
ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
NCCLCHECK(ncclAllReduce(
|
||||
(const void*)data_ptr, (void*)data_ptr, count, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
|
||||
(const void *)data_ptr, (void *)data_ptr, data_size, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
}
|
||||
void custom_allgather_into_tensor(torch::Tensor recv_tensor, torch::Tensor send_tensor, NcclParam *nccl_param_ptr)
|
||||
{
|
||||
void *send_ptr = send_tensor.data_ptr();
|
||||
void *recv_ptr = recv_tensor.data_ptr();
|
||||
size_t data_size = send_tensor.numel();
|
||||
|
||||
torch::ScalarType torch_type = send_tensor.scalar_type();
|
||||
|
||||
NcclParam nccl_param = *nccl_param_ptr;
|
||||
ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
NCCLCHECK(ncclAllGather(
|
||||
(const void *)(send_ptr), (void *)recv_ptr, data_size, nccl_data_type, nccl_param.nccl_comm_, nccl_param.stream_));
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
}
|
||||
|
||||
@ -168,4 +196,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
m.def("init_nccl", &init_nccl, py::return_value_policy::reference, "");
|
||||
m.def("finalize_nccl", &finalize_nccl, "");
|
||||
m.def("custom_allreduce", &custom_allreduce, "");
|
||||
m.def("custom_allgather_into_tensor", &custom_allgather_into_tensor, "");
|
||||
}
|
@ -12,11 +12,13 @@ torch.cuda.set_device(device)
|
||||
torch.cuda.set_per_process_memory_fraction(1., device)
|
||||
print(rank, world_size)
|
||||
|
||||
t = torch.tensor([[1, 2, 3, 4], [3, 3, 3, 3.1]],
|
||||
dtype=torch.float16).to('cuda')
|
||||
|
||||
t = torch.tensor([rank+3.1]*4, dtype=torch.float16).to('cuda').view((1,-1))
|
||||
print(t.shape)
|
||||
world_out = t.new_empty(1, 8)
|
||||
print(my_custom_comm.custom_allgather_into_tensor(world_out, t, tp_ptr))
|
||||
print(my_custom_comm.custom_allreduce(t, tp_ptr))
|
||||
print(t)
|
||||
|
||||
print(t, world_out)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)
|
@ -359,7 +359,6 @@ class TensorParallelHead(SuperLayer):
|
||||
def load(config, prefix: str, weights):
|
||||
if weights.process_group.size() > 1:
|
||||
try:
|
||||
assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
should_gather = True
|
||||
except AssertionError:
|
||||
@ -402,6 +401,9 @@ class TensorParallelHead(SuperLayer):
|
||||
|
||||
torch.mm(input, self.linear.weight.T, out=local_out)
|
||||
|
||||
if USE_CUSTOM_NCCL:
|
||||
my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, self.process_group.tp_comm)
|
||||
else:
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
world_out, gather_input, group=self.process_group
|
||||
)
|
||||
@ -410,6 +412,7 @@ class TensorParallelHead(SuperLayer):
|
||||
return world_out
|
||||
return world_out.T
|
||||
|
||||
assert USE_CUSTOM_NCCL == 0, "should not be here for llama with custom nccl"
|
||||
output = super().forward(input)
|
||||
world_output = [
|
||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||
|
@ -359,7 +359,6 @@ class TensorParallelHead(SuperLayer):
|
||||
def load(config, prefix: str, weights):
|
||||
if weights.process_group.size() > 1:
|
||||
try:
|
||||
assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
should_gather = True
|
||||
except AssertionError:
|
||||
@ -402,6 +401,9 @@ class TensorParallelHead(SuperLayer):
|
||||
|
||||
torch.mm(input, self.linear.weight.T, out=local_out)
|
||||
|
||||
if USE_CUSTOM_NCCL:
|
||||
my_custom_comm.custom_allgather_into_tensor(world_out, gather_input, self.process_group.tp_comm)
|
||||
else:
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
world_out, gather_input, group=self.process_group
|
||||
)
|
||||
@ -410,6 +412,7 @@ class TensorParallelHead(SuperLayer):
|
||||
return world_out
|
||||
return world_out.T
|
||||
|
||||
assert USE_CUSTOM_NCCL == 0, "should not be here for llama with custom nccl"
|
||||
output = super().forward(input)
|
||||
world_output = [
|
||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||
|
Loading…
Reference in New Issue
Block a user