From 1660154ae656e18244261df6244c4daf15e5a38a Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 13 Jan 2025 18:11:31 +0800 Subject: [PATCH] fix crash in torch2.6 if TP=1 (#2885) error like "ValueError: Expecting a ProcessGroup, but got a <class 'text_generation_server.utils.dist.FakeGroup'>. rank=0" Signed-off-by: Wang, Yi A --- server/text_generation_server/utils/dist.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 82aeba6ce..1b766ddf7 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -1,6 +1,6 @@ import os import torch - +from torch.distributed import ProcessGroup from datetime import timedelta from loguru import logger from text_generation_server.utils.import_utils import SYSTEM @@ -18,10 +18,11 @@ class FakeBarrier: pass -class FakeGroup: +class FakeGroup(ProcessGroup): def __init__(self, rank, size): self._rank = rank self._size = size + super().__init__(rank, size) def allreduce(self, *args, **kwargs): return FakeBarrier()