Support xccl distributed backend

Starting from `torch>=2.7`, the XCCL distributed backend is available
for XPU devices (requires torch built with `USE_XCCL=1`).
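
A quick way to check whether a given torch build meets these requirements (a sketch using only public PyTorch APIs; the `getattr` guard is just for builds older than 2.7 that lack the query):

```python
import torch
from torch.distributed import distributed_c10d

# Interactive sanity check for the requirements above.
print(torch.__version__)         # must be >= 2.7
print(torch.xpu.is_available())  # an XPU device must be visible
# is_xccl_available() only exists on torch >= 2.7, hence the guard.
print(getattr(distributed_c10d, "is_xccl_available", lambda: False)())
```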

This commit was verified on Intel Data Center GPU Max with Bloom:
```
text-generation-launcher --sharded true --num-shard 2 \
  --model-id bigscience/bloom-560m
```

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Author: Dmitry Rogozhkin, 2025-02-14 17:08:27 +00:00
Commit: fcdb18d1af (parent: 5543fdc765)

```diff
@@ -3,6 +3,7 @@ import torch
 from torch.distributed import ProcessGroup
 from datetime import timedelta
 from loguru import logger
+from packaging import version
 from text_generation_server.utils.import_utils import SYSTEM
 
 # Tensor Parallelism settings
@@ -45,6 +46,12 @@ class FakeGroup(ProcessGroup):
         return self._rank
 
 
+def _is_xccl_available():
+    if version.parse(torch.__version__).release >= version.parse("2.7").release:
+        return torch.distributed.distributed_c10d.is_xccl_available()
+    return False
+
+
 def initialize_torch_distributed():
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
@@ -54,11 +61,20 @@ def initialize_torch_distributed():
         device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
         torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
+        device = "cuda"
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=120)
+    elif SYSTEM == "xpu" and _is_xccl_available():
+        assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one gpu"
+        device = RANK % torch.xpu.device_count()
+        torch.xpu.set_device(device)
+        device = "xpu"
+        backend = "xccl"
+        options = None
     else:
+        device = None
         backend = "gloo"
         options = None
 
@@ -81,7 +97,8 @@ def initialize_torch_distributed():
                     pg_options=options,
                 )
             else:
-                device = torch.device(f"cuda:{RANK}")
+                if device:
+                    device = torch.device(f"{device}:{RANK}")
                 torch.distributed.init_process_group(
                     backend=backend,
                     world_size=WORLD_SIZE,
@@ -90,6 +107,7 @@ def initialize_torch_distributed():
                     pg_options=options,
                     device_id=device,
                 )
+            logger.info(f"torch.distributed initialized with {backend} backend for rank {RANK}")
         else:
             logger.warning("torch.distributed is already initialized.")
 
```
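
Beyond the Bloom run above, the new backend can also be exercised in isolation with a small all-reduce script (a sketch, not part of this commit; the file name is made up and it assumes two XPUs plus a `torch>=2.7` build with `USE_XCCL=1`):

```python
# xccl_smoke.py -- run with: torchrun --nproc_per_node=2 xccl_smoke.py
import os
import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# One process per XPU, mirroring the RANK % device_count mapping in the patch.
torch.xpu.set_device(rank % torch.xpu.device_count())
dist.init_process_group(backend="xccl", rank=rank, world_size=world_size)

# Every rank contributes ones; after the (sum) all-reduce each rank
# should hold a tensor filled with world_size.
t = torch.ones(8, device="xpu")
dist.all_reduce(t)
assert torch.equal(t.cpu(), torch.full((8,), float(world_size)))

dist.destroy_process_group()
print(f"rank {rank}: xccl all_reduce OK")
```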