mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 13:52:07 +00:00
fix crash in torch2.6 if TP=1 (#2885)
error like "ValueError: Expecting a ProcessGroup, but got a <class 'text_generation_server.utils.dist.FakeGroup'>. rank=0" Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
2e22164d4a
commit
1660154ae6
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
from torch.distributed import ProcessGroup
|
||||
from datetime import timedelta
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
@ -18,10 +18,11 @@ class FakeBarrier:
|
||||
pass
|
||||
|
||||
|
||||
class FakeGroup:
|
||||
class FakeGroup(ProcessGroup):
|
||||
def __init__(self, rank, size):
|
||||
self._rank = rank
|
||||
self._size = size
|
||||
super().__init__(rank, size)
|
||||
|
||||
def allreduce(self, *args, **kwargs):
|
||||
return FakeBarrier()
|
||||
|
Loading…
Reference in New Issue
Block a user