Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Forgot a few places.
parent 1ca91a2ff5
commit 6683e8419a
@@ -5,7 +5,7 @@ from text_generation_server.layers.linear import get_linear, FastLinear
 from text_generation_server.layers.exl2 import Exl2Weight
 from text_generation_server.utils.import_utils import SYSTEM
 
-if SYSTEM in {"xpu", "cpu_ipex"}:
+if SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
 
 
@@ -100,7 +100,7 @@ class TensorParallelHead(SuperLayer):
                 local_out = gather_input.T
 
             torch.mm(input, self.linear.weight.T, out=local_out)
-            if SYSTEM in {"xpu", "cpu_ipex"}:
+            if SYSTEM == "ipex":
                 ipex.distributed.all_gather_into_tensor(
                     world_out, gather_input, group=self.process_group
                 )
@@ -117,7 +117,7 @@ class TensorParallelHead(SuperLayer):
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        if SYSTEM in {"xpu", "cpu_ipex"}:
+        if SYSTEM == "ipex":
             ipex.distributed.all_gather(world_output, output, group=self.process_group)
         else:
             torch.distributed.all_gather(world_output, output, group=self.process_group)
@@ -217,7 +217,7 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            if SYSTEM in {"xpu", "cpu_ipex"}:
+            if SYSTEM == "ipex":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
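
For reference, the pattern this commit converges on is a single SYSTEM == "ipex" check that routes collectives to ipex.distributed when Intel Extension for PyTorch is in use, and to torch.distributed otherwise. Below is a minimal, self-contained sketch of that dispatch; the try/except backend detection and the all_reduce_ helper are illustrative assumptions, not the repository's actual import_utils logic.

import torch
import torch.distributed


# Illustrative backend detection (assumption): in the repository the value
# comes from text_generation_server.utils.import_utils.SYSTEM instead.
try:
    import intel_extension_for_pytorch as ipex

    SYSTEM = "ipex"
except ImportError:
    ipex = None
    SYSTEM = "cuda"


def all_reduce_(out: torch.Tensor, process_group) -> None:
    # Mirrors the reduce step in TensorParallelRowLinear.forward: use the ipex
    # collective only when SYSTEM == "ipex", and torch.distributed otherwise.
    if SYSTEM == "ipex":
        ipex.distributed.all_reduce(out, group=process_group)
    else:
        torch.distributed.all_reduce(out, group=process_group)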