diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 82aeba6ce..1b766ddf7 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -1,6 +1,6 @@ import os import torch - +from torch.distributed import ProcessGroup from datetime import timedelta from loguru import logger from text_generation_server.utils.import_utils import SYSTEM @@ -18,10 +18,11 @@ class FakeBarrier: pass -class FakeGroup: +class FakeGroup(ProcessGroup): def __init__(self, rank, size): self._rank = rank self._size = size + super().__init__(rank, size) def allreduce(self, *args, **kwargs): return FakeBarrier()