Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
Support xccl distributed backend
Starting from `torch>=2.7`, the XCCL distributed backend is available for XPU devices (requires torch built with `USE_XCCL=1`).

This commit was verified on Intel Data Center GPU Max with Bloom:

```
text-generation-launcher --sharded true --num-shard 2 \
    --model-id bigscience/bloom-560m
```

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Parent: 5543fdc765
Commit: fcdb18d1af
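For context, a minimal standalone sketch (not part of this commit) of what the xccl backend enables: a small `all_reduce` across XPU processes. It assumes a `torch>=2.7` build with `USE_XCCL=1` and a launcher such as torchrun setting `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT`:

```python
import os
import torch
import torch.distributed as dist

# Guard: xccl needs torch>=2.7 compiled with USE_XCCL=1 and a visible XPU device.
assert torch.xpu.is_available(), "no XPU device visible"
assert dist.distributed_c10d.is_xccl_available(), "torch was built without USE_XCCL=1"

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# One process per XPU, same rank -> device mapping as the change below.
torch.xpu.set_device(rank % torch.xpu.device_count())

dist.init_process_group(backend="xccl", rank=rank, world_size=world_size)

# Tiny collective to confirm the backend actually communicates.
t = torch.ones(1, device="xpu")
dist.all_reduce(t)
print(f"rank {rank}/{world_size}: all_reduce -> {t.item()}")

dist.destroy_process_group()
```

The diff below wires the same backend selection into TGI's `initialize_torch_distributed()`.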
@@ -3,6 +3,7 @@ import torch
 from torch.distributed import ProcessGroup
 from datetime import timedelta
 from loguru import logger
+from packaging import version
 from text_generation_server.utils.import_utils import SYSTEM

 # Tensor Parallelism settings
@@ -45,6 +46,12 @@ class FakeGroup(ProcessGroup):
         return self._rank


+def _is_xccl_available():
+    if version.parse(torch.__version__).release >= version.parse("2.7").release:
+        return torch.distributed.distributed_c10d.is_xccl_available()
+    return False
+
+
 def initialize_torch_distributed():
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
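An aside on the version gate above (illustration only; the version strings below are made up): comparing `.release` tuples ignores pre-release and local suffixes, so a source build that reports e.g. `2.7.0a0+gitabc123` still satisfies the `>= 2.7` check, whereas a plain `Version` comparison would reject it as a pre-release:

```python
from packaging import version

dev_build = version.parse("2.7.0a0+gitabc123")  # hypothetical source-build version string
release = version.parse("2.7")

print(dev_build.release)                     # (2, 7, 0)
print(dev_build.release >= release.release)  # True  -- what the new helper checks
print(dev_build >= release)                  # False -- pre-releases sort below 2.7
```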
@@ -54,11 +61,20 @@ def initialize_torch_distributed():
         device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
         torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
+        device = "cuda"
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=120)
+    elif SYSTEM == "xpu" and _is_xccl_available():
+        assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one gpu"
+        device = RANK % torch.xpu.device_count()
+        torch.xpu.set_device(device)
+        device = "xpu"
+        backend = "xccl"
+        options = None
     else:
+        device = None
         backend = "gloo"
         options = None

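Purely illustrative, not part of the diff: once the process group exists, the backend that this selection logic ended up with can be confirmed at runtime, which is a quick way to check that the new xpu/xccl branch was actually taken:

```python
import torch.distributed as dist

if dist.is_initialized():
    # Expected: "nccl" on CUDA, "xccl" on XPU with torch>=2.7 (USE_XCCL=1), "gloo" otherwise.
    print(f"active backend: {dist.get_backend()}")
```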
@@ -81,7 +97,8 @@ def initialize_torch_distributed():
                     pg_options=options,
                 )
             else:
-                device = torch.device(f"cuda:{RANK}")
+                if device:
+                    device = torch.device(f"{device}:{RANK}")
                 torch.distributed.init_process_group(
                     backend=backend,
                     world_size=WORLD_SIZE,
@@ -90,6 +107,7 @@ def initialize_torch_distributed():
                     pg_options=options,
                     device_id=device,
                 )
+            logger.info(f"torch.distributed initialized with {backend} backend for rank {RANK}")
         else:
             logger.warning("torch.distributed is already initialized.")

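A side note on the last two hunks: the device string chosen earlier ("cuda", "xpu", or None on the gloo fallback) is composed into a per-rank `torch.device` and passed as `device_id`, binding the process group to that device for CUDA and XPU alike. A tiny sketch of that composition, with a hypothetical rank:

```python
import torch

RANK = 1  # hypothetical rank, for illustration only

for device in ("cuda", "xpu", None):  # None corresponds to the gloo fallback
    bound = torch.device(f"{device}:{RANK}") if device else None
    # init_process_group(..., device_id=bound) then receives cuda:1, xpu:1, or None
    print(device, "->", bound)
```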