From 9ea900de49b523260bf1b6a06bbef679b4869592 Mon Sep 17 00:00:00 2001 From: ur4t <46435411+ur4t@users.noreply.github.com> Date: Thu, 4 Jul 2024 14:50:12 +0800 Subject: [PATCH] simplify initialize_torch_distributed() --- server/text_generation_server/utils/dist.py | 48 ++++++++++----------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 36d63e86..e17f56f0 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -63,31 +63,31 @@ def initialize_torch_distributed(): if WORLD_SIZE == 1: return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE - else: - if os.getenv("DEBUG", None) == "1": - return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE - if not torch.distributed.is_initialized(): - # Call the init process. - if SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex + if os.getenv("DEBUG", None) == "1": + return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE - ipex.distributed.init_process_group( - backend="ccl", - world_size=WORLD_SIZE, - rank=RANK, - timeout=timedelta(seconds=60), - pg_options=options, - ) - else: - torch.distributed.init_process_group( - backend=backend, - world_size=WORLD_SIZE, - rank=RANK, - timeout=timedelta(seconds=60), - pg_options=options, - ) + if not torch.distributed.is_initialized(): + # Call the init process. + if SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + ipex.distributed.init_process_group( + backend="ccl", + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) else: - logger.warning("torch.distributed is already initialized.") + torch.distributed.init_process_group( + backend=backend, + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) + else: + logger.warning("torch.distributed is already initialized.") - return torch.distributed.group.WORLD, RANK, WORLD_SIZE + return torch.distributed.group.WORLD, RANK, WORLD_SIZE