Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
Merge fcdb18d1af into 6b6e30a6f6
This commit is contained in: commit f6254919b3
@@ -3,6 +3,7 @@ import torch
 from torch.distributed import ProcessGroup
 from datetime import timedelta
 from loguru import logger
+from packaging import version
 from text_generation_server.utils.import_utils import SYSTEM

 # Tensor Parallelism settings
@@ -45,6 +46,12 @@ class FakeGroup(ProcessGroup):
         return self._rank


+def _is_xccl_available():
+    if version.parse(torch.__version__).release >= version.parse("2.7").release:
+        return torch.distributed.distributed_c10d.is_xccl_available()
+    return False
+
+
 def initialize_torch_distributed():
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
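The new _is_xccl_available() helper gates XCCL on the PyTorch version before asking the runtime whether the backend was compiled in. It compares packaging release tuples, so local or pre-release suffixes in torch.__version__ do not affect the check. Below is a minimal sketch of the same comparison in isolation; the function name and sample version strings are illustrative, not part of this PR:

from packaging import version

def xccl_version_ok(torch_version: str) -> bool:
    # Compare release tuples only, e.g. (2, 7, 0) >= (2, 7); local or dev
    # suffixes such as "+xpu" or ".dev20250401" are ignored by .release.
    return version.parse(torch_version).release >= version.parse("2.7").release

# Illustrative inputs (not from the PR): anything older than 2.7 is rejected
# before torch is even asked whether the xccl backend is built in.
assert xccl_version_ok("2.7.0+xpu")
assert xccl_version_ok("2.8.0.dev20250401")
assert not xccl_version_ok("2.6.0")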
@@ -54,11 +61,20 @@ def initialize_torch_distributed():
         device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
         torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
+        device = "cuda"
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=120)
+    elif SYSTEM == "xpu" and _is_xccl_available():
+        assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one gpu"
+        device = RANK % torch.xpu.device_count()
+        torch.xpu.set_device(device)
+        device = "xpu"
+        backend = "xccl"
+        options = None
     else:
+        device = None
         backend = "gloo"
         options = None

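With this hunk, backend selection resolves in order: NCCL when CUDA is available, XCCL when the process runs on XPU and both the PyTorch version and the build support it, and gloo otherwise. The sketch below restates that precedence as a standalone function; it substitutes torch.xpu.is_available() for TGI's SYSTEM == "xpu" check, and the function name and return shape are mine, not TGI's API:

import torch

def pick_device_and_backend(xccl_available: bool):
    # Same precedence as the diff: CUDA/NCCL first, then XPU/XCCL, then gloo.
    if torch.cuda.is_available():
        return "cuda", "nccl"
    if hasattr(torch, "xpu") and torch.xpu.is_available() and xccl_available:
        return "xpu", "xccl"
    # No usable accelerator (or no xccl): CPU path, mirroring `device = None`
    # in the final else branch above.
    return None, "gloo"

print(pick_device_and_backend(xccl_available=True))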
@@ -88,7 +104,8 @@ def initialize_torch_distributed():
                 pg_options=options,
             )
         else:
-            device = torch.device(f"cuda:{RANK}")
+            if device:
+                device = torch.device(f"{device}:{RANK}")
             torch.distributed.init_process_group(
                 backend=backend,
                 world_size=WORLD_SIZE,
@@ -97,6 +114,7 @@ def initialize_torch_distributed():
                 pg_options=options,
                 device_id=device,
             )
+            logger.info(f"torch.distributed initialized with {backend} backend for rank {RANK}")
     else:
         logger.warning("torch.distributed is already initialized.")
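The last two hunks adjust the generic init path: because the CUDA and XPU branches now both leave a device name in `device`, the hard-coded cuda:{RANK} device is replaced by a conditional that builds a torch.device from whichever name was set, while the gloo path (where `device` is None) skips the device_id entirely. A small illustration of that string-to-device step, with a made-up rank and assuming a PyTorch build that recognizes the xpu device type:

import torch

rank = 1
for name in ("cuda", "xpu", None):
    # Only construct a torch.device when a backend-specific name was chosen.
    dev = torch.device(f"{name}:{rank}") if name else None
    print(name, "->", dev)
# Expected: cuda -> cuda:1, xpu -> xpu:1, None -> None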