fix crash in torch2.6 if TP=1 (#2885)

error like "ValueError: Expecting a ProcessGroup, but got a <class
'text_generation_server.utils.dist.FakeGroup'>. rank=0"

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2025-01-13 18:11:31 +08:00 committed by GitHub
parent 2e22164d4a
commit 1660154ae6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
import os
import torch
from torch.distributed import ProcessGroup
from datetime import timedelta
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
@ -18,10 +18,11 @@ class FakeBarrier:
pass
class FakeGroup:
class FakeGroup(ProcessGroup):
def __init__(self, rank, size):
self._rank = rank
self._size = size
super().__init__(rank, size)
def allreduce(self, *args, **kwargs):
return FakeBarrier()