mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Fix crash with torch 2.6 when TP=1 (#2885)
The crash surfaces as an error like "ValueError: Expecting a ProcessGroup, but got a <class 'text_generation_server.utils.dist.FakeGroup'>. rank=0".
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
2e22164d4a
commit
1660154ae6
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
@ -18,10 +18,11 @@ class FakeBarrier:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FakeGroup:
|
class FakeGroup(ProcessGroup):
|
||||||
def __init__(self, rank, size):
|
def __init__(self, rank, size):
|
||||||
self._rank = rank
|
self._rank = rank
|
||||||
self._size = size
|
self._size = size
|
||||||
|
super().__init__(rank, size)
|
||||||
|
|
||||||
def allreduce(self, *args, **kwargs):
|
def allreduce(self, *args, **kwargs):
|
||||||
return FakeBarrier()
|
return FakeBarrier()
|
||||||
|
Loading…
Reference in New Issue
Block a user