diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 9785493e..9e83133f 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -2,6 +2,7 @@ import os import torch from datetime import timedelta +from loguru import logger def initialize_torch_distributed(): @@ -23,13 +24,16 @@ def initialize_torch_distributed(): backend = "gloo" options = None - # Call the init process. - torch.distributed.init_process_group( - backend=backend, - world_size=world_size, - rank=rank, - timeout=timedelta(seconds=60), - pg_options=options, - ) + if not torch.distributed.is_initialized(): + # Call the init process. + torch.distributed.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + timeout=timedelta(seconds=60), + pg_options=options, + ) + else: + logger.warning("torch.distributed is already initialized.") return torch.distributed.group.WORLD, rank, world_size