diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index 613c4784b..0e59a091c 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -3,6 +3,7 @@ import torch
 from torch.distributed import ProcessGroup
 from datetime import timedelta
 from loguru import logger
+from packaging import version
 from text_generation_server.utils.import_utils import SYSTEM
 
 # Tensor Parallelism settings
@@ -45,6 +46,12 @@ class FakeGroup(ProcessGroup):
         return self._rank
 
 
+def _is_xccl_available():
+    if version.parse(torch.__version__).release >= version.parse("2.7").release:
+        return torch.distributed.distributed_c10d.is_xccl_available()
+    return False
+
+
 def initialize_torch_distributed():
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
@@ -54,11 +61,20 @@ def initialize_torch_distributed():
         device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
         torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
+        device = "cuda"
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=120)
+    elif SYSTEM == "xpu" and _is_xccl_available():
+        assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one gpu"
+        device = RANK % torch.xpu.device_count()
+        torch.xpu.set_device(device)
+        device = "xpu"
+        backend = "xccl"
+        options = None
     else:
+        device = None
         backend = "gloo"
         options = None
 
@@ -81,7 +97,8 @@ def initialize_torch_distributed():
                 pg_options=options,
             )
         else:
-            device = torch.device(f"cuda:{RANK}")
+            if device:
+                device = torch.device(f"{device}:{RANK}")
             torch.distributed.init_process_group(
                 backend=backend,
                 world_size=WORLD_SIZE,
@@ -90,6 +107,7 @@ def initialize_torch_distributed():
                 pg_options=options,
                 device_id=device,
             )
+            logger.info(f"torch.distributed initialized with {backend} backend for rank {RANK}")
     else:
         logger.warning("torch.distributed is already initialized.")
 
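
Not part of the patch itself: a minimal sketch of how the updated initializer could be exercised per rank (for example under torchrun). It relies only on the standard torch.distributed environment variables that dist.py already reads at import time; the defaults, the print, and the single-process invocation are illustrative assumptions, not code from this change, and the return shape follows the existing module.

import os

# Illustrative defaults; a launcher such as torchrun sets these per rank.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# RANK and WORLD_SIZE are read at module import time, so set them first.
from text_generation_server.utils.dist import initialize_torch_distributed

# With this patch the backend resolves to "nccl" on CUDA hosts, "xccl" on XPU
# hosts running torch >= 2.7, and "gloo" otherwise (single-process runs may
# short-circuit to the module's FakeGroup without initializing torch.distributed).
process_group, rank, world_size = initialize_torch_distributed()
print(rank, world_size, type(process_group).__name__)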