From fcdb18d1af9aaa6fef2196b789ef45816373bb19 Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin
Date: Fri, 14 Feb 2025 17:08:27 +0000
Subject: [PATCH] Support xccl distributed backend

Starting from `torch>=2.7`, the XCCL distributed backend is available
for XPU devices (requires torch built with `USE_XCCL=1`).

This commit was verified on Intel Data Center GPU Max with Bloom:

```
text-generation-launcher --sharded true --num-shard 2 \
    --model-id bigscience/bloom-560m
```

Signed-off-by: Dmitry Rogozhkin
---
 server/text_generation_server/utils/dist.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index 613c4784..0e59a091 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -3,6 +3,7 @@ import torch
 from torch.distributed import ProcessGroup
 from datetime import timedelta
 from loguru import logger
+from packaging import version
 from text_generation_server.utils.import_utils import SYSTEM
 
 # Tensor Parallelism settings
@@ -45,6 +46,12 @@ class FakeGroup(ProcessGroup):
         return self._rank
 
 
+def _is_xccl_available():
+    if version.parse(torch.__version__).release >= version.parse("2.7").release:
+        return torch.distributed.distributed_c10d.is_xccl_available()
+    return False
+
+
 def initialize_torch_distributed():
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
@@ -54,11 +61,20 @@ def initialize_torch_distributed():
         device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
         torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
+        device = "cuda"
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=120)
+    elif SYSTEM == "xpu" and _is_xccl_available():
+        assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one gpu"
+        device = RANK % torch.xpu.device_count()
+        torch.xpu.set_device(device)
+        device = "xpu"
+        backend = "xccl"
+        options = None
     else:
+        device = None
         backend = "gloo"
         options = None
 
@@ -81,7 +97,8 @@ def initialize_torch_distributed():
                     pg_options=options,
                 )
             else:
-                device = torch.device(f"cuda:{RANK}")
+                if device:
+                    device = torch.device(f"{device}:{RANK}")
                 torch.distributed.init_process_group(
                     backend=backend,
                     world_size=WORLD_SIZE,
@@ -90,6 +107,7 @@ def initialize_torch_distributed():
                     pg_options=options,
                     device_id=device,
                 )
+                logger.info(f"torch.distributed initialized with {backend} backend for rank {RANK}")
         else:
             logger.warning("torch.distributed is already initialized.")
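
Note (illustration, not part of the patch): the backend selection added in
dist.py can be exercised in isolation. A minimal sketch, assuming a
`torch>=2.7` XPU build with `USE_XCCL=1` and the usual torch.distributed
environment variables (`RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`)
already set by the launcher:

```python
import os
from datetime import timedelta

import torch
import torch.distributed as dist
from packaging import version

# Probe for XCCL support the same way _is_xccl_available() does.
has_xccl = (
    version.parse(torch.__version__).release >= version.parse("2.7").release
    and torch.distributed.distributed_c10d.is_xccl_available()
)
print(f"torch {torch.__version__}, xccl available: {has_xccl}")

if has_xccl:
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    # Bind this process to one XPU, mirroring the patched initialize_torch_distributed().
    torch.xpu.set_device(rank % torch.xpu.device_count())

    # Rendezvous over MASTER_ADDR/MASTER_PORT; xccl takes no extra ProcessGroup options.
    dist.init_process_group(
        backend="xccl",
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=120),
        device_id=torch.device(f"xpu:{rank}"),
    )
```

With `--sharded true --num-shard 2`, each shard should then log
`torch.distributed initialized with xccl backend for rank <N>` at startup.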