commit f6254919b3
Author: Dmitry Rogozhkin
Date:   2025-05-30 04:24:48 +02:00 (committed by GitHub)

@@ -3,6 +3,7 @@ import torch
 from torch.distributed import ProcessGroup
 from datetime import timedelta
 from loguru import logger
+from packaging import version
 
 from text_generation_server.utils.import_utils import SYSTEM
 # Tensor Parallelism settings
@@ -45,6 +46,12 @@ class FakeGroup(ProcessGroup):
         return self._rank
 
 
+def _is_xccl_available():
+    if version.parse(torch.__version__).release >= version.parse("2.7").release:
+        return torch.distributed.distributed_c10d.is_xccl_available()
+    return False
+
+
 def initialize_torch_distributed():
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
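
For reference on the new _is_xccl_available gate: packaging.version.parse(...).release strips pre-release and local-version markers and returns the numeric release tuple, so the check above is a plain tuple comparison against (2, 7). A small standalone sketch, not part of the patch:

    from packaging import version

    print(version.parse("2.7.0").release)                                    # (2, 7, 0)
    print(version.parse("2.6.0+xpu").release)                                # (2, 6, 0), the local "+xpu" tag is ignored
    print(version.parse("2.7.0a0").release >= version.parse("2.7").release)  # True: 2.7 pre-releases pass the gate
    print(version.parse("2.6.9").release >= version.parse("2.7").release)    # False
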
@@ -54,11 +61,20 @@ def initialize_torch_distributed():
         device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
         torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
+        device = "cuda"
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=120)
+    elif SYSTEM == "xpu" and _is_xccl_available():
+        assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one gpu"
+        device = RANK % torch.xpu.device_count()
+        torch.xpu.set_device(device)
+        device = "xpu"
+        backend = "xccl"
+        options = None
     else:
+        device = None
         backend = "gloo"
         options = None
@@ -88,7 +104,8 @@ def initialize_torch_distributed():
                     pg_options=options,
                 )
             else:
-                device = torch.device(f"cuda:{RANK}")
+                if device:
+                    device = torch.device(f"{device}:{RANK}")
                 torch.distributed.init_process_group(
                     backend=backend,
                     world_size=WORLD_SIZE,
@@ -97,6 +114,7 @@ def initialize_torch_distributed():
                     pg_options=options,
                     device_id=device,
                 )
+            logger.info(f"torch.distributed initialized with {backend} backend for rank {RANK}")
         else:
             logger.warning("torch.distributed is already initialized.")
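
Taken together, the patch picks the backend and a device string up front, and the torch.device(f"{device}:{RANK}") conversion in the later hunk only runs when a device was actually chosen, so the gloo fallback ends up calling init_process_group with device_id=None. Below is a trimmed-down, illustrative mirror of that selection logic; pick_backend_and_device is a made-up helper name, while SYSTEM, RANK and WORLD_SIZE correspond to the module-level names used in the file:

    import torch
    from packaging import version


    def _is_xccl_available():
        # Same gate as in the patch: XCCL is only queried on torch >= 2.7.
        if version.parse(torch.__version__).release >= version.parse("2.7").release:
            return torch.distributed.distributed_c10d.is_xccl_available()
        return False


    def pick_backend_and_device(system, rank, world_size):
        # Illustrative sketch of the branching added above, not the repository's code.
        if torch.cuda.is_available():
            assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
            torch.cuda.set_device(rank % torch.cuda.device_count())
            return "nccl", "cuda"
        if system == "xpu" and _is_xccl_available():
            assert world_size <= torch.xpu.device_count(), "Each process is one gpu"
            torch.xpu.set_device(rank % torch.xpu.device_count())
            return "xccl", "xpu"
        # Fallback keeps gloo and leaves the device string unset, so the
        # torch.device conversion is skipped and device_id stays None.
        return "gloo", None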