mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 03:14:53 +00:00
fix(server): Do not init process group if already initialized (#388)
This commit is contained in:
parent
aefde28b45
commit
ae466a8736
@ -2,6 +2,7 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
class FakeBarrier:
|
class FakeBarrier:
|
||||||
@ -59,13 +60,17 @@ def initialize_torch_distributed():
|
|||||||
else:
|
else:
|
||||||
if os.getenv("DEBUG", None) == "1":
|
if os.getenv("DEBUG", None) == "1":
|
||||||
return FakeGroup(rank, world_size), rank, world_size
|
return FakeGroup(rank, world_size), rank, world_size
|
||||||
# Call the init process.
|
|
||||||
torch.distributed.init_process_group(
|
if not torch.distributed.is_initialized():
|
||||||
backend=backend,
|
# Call the init process.
|
||||||
world_size=world_size,
|
torch.distributed.init_process_group(
|
||||||
rank=rank,
|
backend=backend,
|
||||||
timeout=timedelta(seconds=60),
|
world_size=world_size,
|
||||||
pg_options=options,
|
rank=rank,
|
||||||
)
|
timeout=timedelta(seconds=60),
|
||||||
|
pg_options=options,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning("torch.distributed is already initialized.")
|
||||||
|
|
||||||
return torch.distributed.group.WORLD, rank, world_size
|
return torch.distributed.group.WORLD, rank, world_size
|
||||||
|
Loading…
Reference in New Issue
Block a user