simplify initialize_torch_distributed()

Author: ur4t
Date:   2024-07-04 14:50:12 +08:00
Parent: 5ad41aa2a6
Commit: 9ea900de49


@@ -63,31 +63,31 @@ def initialize_torch_distributed():
    if WORLD_SIZE == 1:
        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
    else:
        if os.getenv("DEBUG", None) == "1":
            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE

        if not torch.distributed.is_initialized():
            # Call the init process.
            if SYSTEM == "ipex":
                import intel_extension_for_pytorch as ipex

                ipex.distributed.init_process_group(
                    backend="ccl",
                    world_size=WORLD_SIZE,
                    rank=RANK,
                    timeout=timedelta(seconds=60),
                    pg_options=options,
                )
            else:
                torch.distributed.init_process_group(
                    backend=backend,
                    world_size=WORLD_SIZE,
                    rank=RANK,
                    timeout=timedelta(seconds=60),
                    pg_options=options,
                )
        else:
            logger.warning("torch.distributed is already initialized.")

        return torch.distributed.group.WORLD, RANK, WORLD_SIZE
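
Note: FakeGroup is defined elsewhere in dist.py and is not part of this hunk. The sketch below is only an illustration of the interface that the WORLD_SIZE == 1 and DEBUG=1 early returns rely on (a group-like object that reports rank and size without creating a real backend); the method set and the FakeBarrier helper shown here are assumptions for illustration, not the file's actual definition.

    # Illustrative stand-in for the FakeGroup used by the early-return paths.
    class FakeBarrier:
        def wait(self):
            # Nothing to synchronize with in a single-process run.
            pass


    class FakeGroup:
        def __init__(self, rank, size):
            self._rank = rank
            self._size = size

        def rank(self):
            return self._rank

        def size(self):
            return self._size

        def barrier(self, *args, **kwargs):
            # Collectives degenerate to no-ops when there is only one process.
            return FakeBarrier()


    if __name__ == "__main__":
        group = FakeGroup(0, 1)
        print(group.rank(), group.size())  # prints: 0 1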