Forgot a few places.

This commit is contained in:
Nicolas Patry 2024-06-25 10:43:38 +00:00
parent 1ca91a2ff5
commit 6683e8419a

View File

@ -5,7 +5,7 @@ from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM in {"xpu", "cpu_ipex"}:
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
@ -100,7 +100,7 @@ class TensorParallelHead(SuperLayer):
local_out = gather_input.T
torch.mm(input, self.linear.weight.T, out=local_out)
if SYSTEM in {"xpu", "cpu_ipex"}:
if SYSTEM == "ipex":
ipex.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
@ -117,7 +117,7 @@ class TensorParallelHead(SuperLayer):
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
if SYSTEM in {"xpu", "cpu_ipex"}:
if SYSTEM == "ipex":
ipex.distributed.all_gather(world_output, output, group=self.process_group)
else:
torch.distributed.all_gather(world_output, output, group=self.process_group)
@ -217,7 +217,7 @@ class TensorParallelRowLinear(SuperLayer):
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input)
if self.process_group.size() > 1 and reduce:
if SYSTEM in {"xpu", "cpu_ipex"}:
if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)