Support xccl distributed backend

Starting with `torch>=2.7`, the XCCL distributed backend is available
for XPU devices (requires torch built with `USE_XCCL=1`).
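
For context, a minimal sketch of bringing the backend up directly; the single-rank setup and the rendezvous endpoint are illustrative assumptions, not part of this commit:

```python
import torch
import torch.distributed as dist

# Requires a torch>=2.7 build with USE_XCCL=1; on earlier torch versions
# this helper does not exist at all (hence the version gate in the diff below).
assert torch.distributed.distributed_c10d.is_xccl_available()

torch.xpu.set_device(0)  # illustrative: pin this process to the first XPU
dist.init_process_group(
    backend="xccl",
    init_method="tcp://127.0.0.1:29500",  # hypothetical rendezvous endpoint
    rank=0,
    world_size=1,
)
```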

This commit was verified on Intel Data Center GPU Max with BLOOM:
```
text-generation-launcher --sharded true --num-shard 2 \
  --model-id bigscience/bloom-560m
```
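
With `--num-shard 2`, each rank should emit the log line added in this commit, e.g. `torch.distributed initialized with xccl backend for rank 0`.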

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Author: Dmitry Rogozhkin
Date:   2025-02-14 17:08:27 +00:00
Commit: fcdb18d1af
Parent: 5543fdc765


```diff
@@ -3,6 +3,7 @@ import torch
 from torch.distributed import ProcessGroup
 from datetime import timedelta
 from loguru import logger
+from packaging import version
 from text_generation_server.utils.import_utils import SYSTEM
 
 # Tensor Parallelism settings
@@ -45,6 +46,12 @@ class FakeGroup(ProcessGroup):
         return self._rank
 
 
+def _is_xccl_available():
+    if version.parse(torch.__version__).release >= version.parse("2.7").release:
+        return torch.distributed.distributed_c10d.is_xccl_available()
+    return False
+
+
 def initialize_torch_distributed():
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
@@ -54,11 +61,20 @@ def initialize_torch_distributed():
         device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
         torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
+        device = "cuda"
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=120)
+    elif SYSTEM == "xpu" and _is_xccl_available():
+        assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one gpu"
+        device = RANK % torch.xpu.device_count()
+        torch.xpu.set_device(device)
+        device = "xpu"
+        backend = "xccl"
+        options = None
     else:
+        device = None
         backend = "gloo"
         options = None
@@ -81,7 +97,8 @@ def initialize_torch_distributed():
             pg_options=options,
         )
     else:
-        device = torch.device(f"cuda:{RANK}")
+        if device:
+            device = torch.device(f"{device}:{RANK}")
         torch.distributed.init_process_group(
             backend=backend,
             world_size=WORLD_SIZE,
@@ -90,6 +107,7 @@ def initialize_torch_distributed():
             pg_options=options,
             device_id=device,
         )
+        logger.info(f"torch.distributed initialized with {backend} backend for rank {RANK}")
     else:
         logger.warning("torch.distributed is already initialized.")
```
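
One detail worth noting in `_is_xccl_available`: it compares `.release` tuples rather than full `Version` objects, because PEP 440 pre-releases sort below the final release and a direct comparison would reject torch nightlies. A standalone illustration (the nightly tag is a hypothetical example):

```python
from packaging import version

nightly = version.parse("2.7.0.dev20250101")  # hypothetical nightly build tag
stable = version.parse("2.7")

print(nightly >= stable)                  # False: pre-releases sort below 2.7
print(nightly.release >= stable.release)  # True: (2, 7, 0) >= (2, 7)
```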