commit f6254919b3
Author: Dmitry Rogozhkin
Date:   2025-05-30 04:24:48 +02:00
Committed by: GitHub


@@ -3,6 +3,7 @@ import torch
 from torch.distributed import ProcessGroup
 from datetime import timedelta
 from loguru import logger
+from packaging import version
 from text_generation_server.utils.import_utils import SYSTEM
 
 # Tensor Parallelism settings
@@ -45,6 +46,12 @@ class FakeGroup(ProcessGroup):
         return self._rank
 
 
+def _is_xccl_available():
+    if version.parse(torch.__version__).release >= version.parse("2.7").release:
+        return torch.distributed.distributed_c10d.is_xccl_available()
+    return False
+
+
 def initialize_torch_distributed():
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
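
The version gate in _is_xccl_available compares .release tuples rather than full Version objects on purpose: XPU-enabled PyTorch builds often carry pre-release or local suffixes (e.g. 2.7.0a0+git...), and a plain Version comparison would rank those below 2.7 and wrongly disable XCCL. A minimal sketch of the difference (the version string below is illustrative, not from the commit):

    from packaging import version

    dev = version.parse("2.7.0a0+git1234567")   # typical XPU dev-build version string
    gate = version.parse("2.7")

    print(dev >= gate)                  # False: pre-releases sort below the 2.7 release
    print(dev.release >= gate.release)  # True: (2, 7, 0) >= (2, 7) as plain tuples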
@@ -54,11 +61,20 @@ def initialize_torch_distributed():
         device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
         torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
+        device = "cuda"
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=120)
+    elif SYSTEM == "xpu" and _is_xccl_available():
+        assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one gpu"
+        device = RANK % torch.xpu.device_count()
+        torch.xpu.set_device(device)
+        device = "xpu"
+        backend = "xccl"
+        options = None
     else:
+        device = None
         backend = "gloo"
         options = None
 
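
Read together, the branch above now resolves to one of three (device, backend) pairs: ("cuda", "nccl"), ("xpu", "xccl"), or (None, "gloo"). A standalone sketch of that selection, usable for probing a machine outside the server; it substitutes torch.xpu.is_available() for TGI's SYSTEM == "xpu" check, so treat it as an approximation rather than the commit's exact logic:

    import torch
    from packaging import version

    def _xccl_available():
        # Mirrors the commit's gate: torch >= 2.7 and the xccl backend compiled in.
        # The version check short-circuits, so the attribute is never touched on older torch.
        return (
            version.parse(torch.__version__).release >= version.parse("2.7").release
            and torch.distributed.distributed_c10d.is_xccl_available()
        )

    def pick_backend():
        if torch.cuda.is_available():
            return "cuda", "nccl"
        if hasattr(torch, "xpu") and torch.xpu.is_available() and _xccl_available():
            return "xpu", "xccl"
        return None, "gloo"

    print(pick_backend())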
@@ -88,7 +104,8 @@ def initialize_torch_distributed():
                 pg_options=options,
             )
         else:
-            device = torch.device(f"cuda:{RANK}")
+            if device:
+                device = torch.device(f"{device}:{RANK}")
             torch.distributed.init_process_group(
                 backend=backend,
                 world_size=WORLD_SIZE,
@@ -97,6 +114,7 @@ def initialize_torch_distributed():
                 pg_options=options,
                 device_id=device,
             )
+        logger.info(f"torch.distributed initialized with {backend} backend for rank {RANK}")
     else:
         logger.warning("torch.distributed is already initialized.")
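
End-to-end, the patched function initializes XPU ranks over xccl instead of silently falling back to gloo. A hedged usage sketch, assuming the function lives in text_generation_server.utils.dist and (as the FakeGroup path suggests) returns a (process_group, rank, world_size) tuple; launch with e.g. torchrun --nproc_per_node=2 example.py:

    import torch
    from text_generation_server.utils.dist import initialize_torch_distributed  # module path assumed

    pg, rank, world_size = initialize_torch_distributed()

    # Pick the device matching whichever backend was selected above.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device = torch.device(f"xpu:{rank}")
    else:
        device = torch.device("cpu")

    # One collective to exercise the backend (nccl, xccl, or gloo).
    t = torch.full((1,), float(rank), device=device)
    torch.distributed.all_reduce(t, group=pg)
    print(f"rank {rank}/{world_size}: sum of ranks = {t.item()}")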