mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00

Merge f147f10ed4 into 0627983c17 (commit 97bbd136de)
@@ -45,7 +45,7 @@ RUN cargo build --profile release-opt --frozen

 # Text Generation Inference base image for Intel

-FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu
+FROM intel/oneapi-basekit:2025.1.3-0-devel-ubuntu22.04 AS xpu

 USER root
@@ -99,7 +99,8 @@ ENV HF_HOME=/data \

 WORKDIR /usr/src

-RUN pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/xpu
+#RUN pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/xpu
+RUN pip install --pre torch==2.8.0.dev20250526+xpu torchvision==0.22.0.dev20250526+xpu --index-url https://download.pytorch.org/whl/nightly/xpu

 # Install server
 COPY proto proto
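The torch/torchvision install moves from the stable XPU index to a pinned nightly build (2.8.0.dev20250526+xpu). A quick sanity check that the wheel landed correctly (a hypothetical snippet, not part of this PR) can be run inside the image:

    # check_torch_xpu.py -- hypothetical sanity check, not part of this PR
    import torch

    # A correct install reports a 2.8.0.dev*+xpu version string.
    print("torch:", torch.__version__)

    # torch.xpu is PyTorch's built-in Intel GPU backend; this is False
    # when no XPU device or driver is visible to the container.
    print("xpu available:", torch.xpu.is_available())
    if torch.xpu.is_available():
        print("xpu devices:", torch.xpu.device_count())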
@@ -117,8 +118,7 @@ ENV TORCH_LLM_ALLREDUCE=1
 ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0

-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.7.0%2Bxpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.7.10%2Bxpu-cp311-cp311-linux_x86_64.whl
+#RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.7.10%2Bxpu-cp311-cp311-linux_x86_64.whl

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
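With oneccl_bind_pt dropped and intel_extension_for_pytorch commented out of the XPU stage, collectives on XPU must come entirely from stock torch.distributed. A hedged way to confirm the nightly build ships the native oneCCL-based backend, assuming torch.distributed.is_xccl_available() exists in this nightly:

    # check_xccl.py -- illustrative only; is_xccl_available() is assumed
    # to be present in the torch 2.8 nightly installed above.
    import torch.distributed as dist

    # True when torch itself was built with the XCCL backend, which
    # replaces the separate oneccl_bind_pt wheel removed here.
    print("xccl available:", dist.is_xccl_available())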
@@ -90,7 +90,7 @@ class TensorParallelHead(SuperLayer):
                 local_out = gather_input.T

             torch.mm(input, self.linear.weight.T, out=local_out)
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and gather_input.device.type == "cpu":
                 ipex.distributed.all_gather_into_tensor(
                     world_out, gather_input, group=self.process_group
                 )
@@ -107,7 +107,7 @@ class TensorParallelHead(SuperLayer):
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        if SYSTEM == "ipex":
+        if SYSTEM == "ipex" and output.device.type == "cpu":
             ipex.distributed.all_gather(world_output, output, group=self.process_group)
         else:
             torch.distributed.all_gather(world_output, output, group=self.process_group)
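Both TensorParallelHead changes follow the same pattern: ipex.distributed is now reserved for CPU tensors, while XPU tensors fall through to torch.distributed, whose xccl backend (initialized in dist.py below) covers Intel GPUs. A minimal sketch of that dispatch, with hypothetical names rather than the PR's code:

    # gather_dispatch_sketch.py -- hypothetical illustration of the new
    # gating; SYSTEM, ipex, and the process group are TGI's own objects.
    import torch
    import torch.distributed


    def all_gather_any(world_out, gather_input, process_group, system):
        # CPU tensors on an ipex build still need ipex's oneCCL bindings;
        # anything else (xpu, cuda) uses whatever torch.distributed
        # backend was chosen at init time (xccl, nccl, ...).
        if system == "ipex" and gather_input.device.type == "cpu":
            import intel_extension_for_pytorch as ipex

            ipex.distributed.all_gather_into_tensor(
                world_out, gather_input, group=process_group
            )
        else:
            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=process_group
            )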
@@ -202,7 +202,7 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and out.device.type == "cpu":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
@@ -242,7 +242,7 @@ class TensorParallelEmbedding(torch.nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and out.device.type == "cpu":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
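TensorParallelRowLinear and TensorParallelEmbedding get the identical one-line guard on their all_reduce calls. Keeping the check inline at all four call sites rather than factoring out a helper leaves the CPU path byte-for-byte unchanged; the guard does assume torch.distributed was initialized with a backend that can reduce non-CPU tensors, which is exactly what the dist.py change below arranges.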
@@ -79,14 +79,23 @@ def initialize_torch_distributed():
                 ), "Each process is one xpu"
                 device = RANK % torch.xpu.device_count()
                 torch.xpu.set_device(device)
-
-                ipex.distributed.init_process_group(
-                    backend="ccl",
+                device_id = torch.device(f"xpu:{RANK}")
+                torch.distributed.init_process_group(
+                    backend="xccl",
                     world_size=WORLD_SIZE,
                     rank=RANK,
                     timeout=timedelta(seconds=120),
                     pg_options=options,
-                )
+                    device_id=device_id,
+                )
+            else:
+                ipex.distributed.init_process_group(
+                    backend="ccl",
+                    world_size=WORLD_SIZE,
+                    rank=RANK,
+                    timeout=timedelta(seconds=120),
+                    pg_options=options,
+                )
             else:
                 device = torch.device(f"cuda:{RANK}")
                 torch.distributed.init_process_group(
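On XPU the process group is now created by stock torch.distributed with backend="xccl", and device_id binds the group to this rank's device at init time instead of on the first collective; ipex.distributed with the ccl backend survives only for the CPU case. A stripped-down sketch of the new XPU branch, assuming RANK and WORLD_SIZE come from the launcher environment:

    # init_xccl_sketch.py -- minimal sketch, not TGI's full
    # initialize_torch_distributed(); assumes a torch 2.8 nightly +xpu
    # build where backend="xccl" is valid.
    import os
    from datetime import timedelta

    import torch
    import torch.distributed

    RANK = int(os.environ["RANK"])
    WORLD_SIZE = int(os.environ["WORLD_SIZE"])

    torch.xpu.set_device(RANK % torch.xpu.device_count())
    torch.distributed.init_process_group(
        backend="xccl",
        world_size=WORLD_SIZE,
        rank=RANK,
        timeout=timedelta(seconds=120),
        # Eagerly binding a device avoids lazy backend selection on the
        # first collective call.
        device_id=torch.device(f"xpu:{RANK}"),
    )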