import os
import torch
from torch.distributed import ProcessGroup
from datetime import timedelta
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM

# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))

# CUDA memory fraction
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
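# RANK and WORLD_SIZE are the standard per-worker environment variables exported by
# distributed launchers such as `torchrun`; CUDA_MEMORY_FRACTION caps the share of each
# GPU this process may allocate (applied via set_per_process_memory_fraction below).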


class FakeBarrier:
    def wait(self):
        pass
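

# FakeGroup implements the minimal ProcessGroup surface this module needs so that
# single-process (WORLD_SIZE == 1) and DEBUG runs can skip real collective communication:
# allreduce/barrier are no-ops and allgather simply copies the local tensor.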
class FakeGroup(ProcessGroup):
    def __init__(self, rank, size):
        self._rank = rank
        self._size = size
        super().__init__(rank, size)

    def allreduce(self, *args, **kwargs):
        return FakeBarrier()

    def allgather(self, inputs, local_tensor, **kwargs):
        assert (
            len(inputs[0]) == len(local_tensor) == 1
        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
        for input_ in inputs:
            input_[0].data = local_tensor[0].data
        return FakeBarrier()

    def barrier(self, *args, **kwargs):
        return FakeBarrier()

    def size(self):
        return self._size

    def rank(self):
        return self._rank
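

# initialize_torch_distributed() selects the backend (NCCL on CUDA, ccl via ipex on XPU,
# gloo otherwise), pins this process to its device, and returns a
# (process_group, rank, world_size) tuple; WORLD_SIZE == 1 and DEBUG=1 runs short-circuit
# to a FakeGroup instead of initializing torch.distributed.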
def initialize_torch_distributed():
    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        # Set the device id.
        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
        device = RANK % torch.cuda.device_count()
        torch.cuda.set_device(device)
        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=120)
    else:
        backend = "gloo"
        options = None

    if WORLD_SIZE == 1:
        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
    else:
        if os.getenv("DEBUG", None) == "1":
            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE

        if not torch.distributed.is_initialized():
            # Call the init process.
            if SYSTEM == "ipex":
                import intel_extension_for_pytorch as ipex

                if torch.xpu.is_available():
                    assert (
                        WORLD_SIZE <= torch.xpu.device_count()
                    ), "Each process is one xpu"
                    device = RANK % torch.xpu.device_count()
                    torch.xpu.set_device(device)

                ipex.distributed.init_process_group(
                    backend="ccl",
                    world_size=WORLD_SIZE,
                    rank=RANK,
                    timeout=timedelta(seconds=120),
                    pg_options=options,
                )
            else:
                device = torch.device(f"cuda:{RANK}")
                torch.distributed.init_process_group(
                    backend=backend,
                    world_size=WORLD_SIZE,
                    rank=RANK,
                    timeout=timedelta(seconds=120),
                    pg_options=options,
                    device_id=device,
                )
        else:
            logger.warning("torch.distributed is already initialized.")

        return torch.distributed.group.WORLD, RANK, WORLD_SIZE
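

# Illustrative usage sketch (an assumption, not part of the original module): a launcher
# such as `torchrun --nproc_per_node=2 your_script.py` exports RANK and WORLD_SIZE for
# each worker, after which each process would call:
#
#     process_group, rank, world_size = initialize_torch_distributed()
#
# With WORLD_SIZE > 1 this joins the real NCCL/gloo/ccl group; with WORLD_SIZE == 1 (or
# DEBUG=1) the returned FakeGroup makes `process_group.allreduce(...)` and
# `process_group.barrier()` return immediately via FakeBarrier.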