simplify initialize_torch_distributed()

ur4t 2024-07-04 14:50:12 +08:00
parent 5ad41aa2a6
commit 9ea900de49


@@ -63,31 +63,31 @@ def initialize_torch_distributed():
     if WORLD_SIZE == 1:
         return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
-    else:
-        if os.getenv("DEBUG", None) == "1":
-            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
-
-        if not torch.distributed.is_initialized():
-            # Call the init process.
-            if SYSTEM == "ipex":
-                import intel_extension_for_pytorch as ipex
-
-                ipex.distributed.init_process_group(
-                    backend="ccl",
-                    world_size=WORLD_SIZE,
-                    rank=RANK,
-                    timeout=timedelta(seconds=60),
-                    pg_options=options,
-                )
-            else:
-                torch.distributed.init_process_group(
-                    backend=backend,
-                    world_size=WORLD_SIZE,
-                    rank=RANK,
-                    timeout=timedelta(seconds=60),
-                    pg_options=options,
-                )
-        else:
-            logger.warning("torch.distributed is already initialized.")
-
-        return torch.distributed.group.WORLD, RANK, WORLD_SIZE
+
+    if os.getenv("DEBUG", None) == "1":
+        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
+
+    if not torch.distributed.is_initialized():
+        # Call the init process.
+        if SYSTEM == "ipex":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.distributed.init_process_group(
+                backend="ccl",
+                world_size=WORLD_SIZE,
+                rank=RANK,
+                timeout=timedelta(seconds=60),
+                pg_options=options,
+            )
+        else:
+            torch.distributed.init_process_group(
+                backend=backend,
+                world_size=WORLD_SIZE,
+                rank=RANK,
+                timeout=timedelta(seconds=60),
+                pg_options=options,
+            )
+    else:
+        logger.warning("torch.distributed is already initialized.")
+
+    return torch.distributed.group.WORLD, RANK, WORLD_SIZE
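
The change is a pure control-flow cleanup: because the WORLD_SIZE == 1 branch already returns, the enclosing else: contributes nothing except an extra indentation level, and dropping it lets the rest of the function read as a flat series of guard clauses. A minimal, self-contained sketch of the same early-return refactor (hypothetical function and names, not code from this repository):

    # Before: the else: block only exists because of the early return above it.
    def pick_group_nested(world_size: int, debug: bool) -> str:
        if world_size == 1:
            return "fake"
        else:
            if debug:
                return "fake"
            return "real"

    # After: identical behavior, one less nesting level for every remaining line.
    def pick_group_flat(world_size: int, debug: bool) -> str:
        if world_size == 1:
            return "fake"
        if debug:
            return "fake"
        return "real"

Both versions return the same value for every input; the refactor only removes the redundant else: and dedents its body, which is exactly what the diff above does to initialize_torch_distributed().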