From e72d2574c8e2b47a9520c79558cbfadfd600061a Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Mon, 26 May 2025 20:22:03 -0700
Subject: [PATCH 1/2] use xccl

Signed-off-by: Wang, Yi A
---
 Dockerfile_intel                             |  6 +++---
 .../layers/tensor_parallel.py                |  8 ++++----
 server/text_generation_server/utils/dist.py | 25 +++++++++++++++++--------
 3 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/Dockerfile_intel b/Dockerfile_intel
index 3bc04332..e9a16d22 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -45,7 +45,7 @@ RUN cargo build --profile release-opt --frozen
 
 # Text Generation Inference base image for Intel
 
-FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu
+FROM intel/oneapi-basekit:2025.1.3-0-devel-ubuntu22.04 AS xpu
 
 USER root
 
@@ -99,7 +99,8 @@ ENV HF_HOME=/data \
 
 WORKDIR /usr/src
 
-RUN pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/xpu
+#RUN pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/xpu
+RUN pip install --pre torch==2.8.0.dev20250526+xpu torchvision==0.22.0.dev20250526+xpu --index-url https://download.pytorch.org/whl/nightly/xpu
 
 # Install server
 COPY proto proto
@@ -117,7 +118,6 @@ ENV TORCH_LLM_ALLREDUCE=1
 ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
 
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.7.0%2Bxpu-cp311-cp311-linux_x86_64.whl
 RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.7.10%2Bxpu-cp311-cp311-linux_x86_64.whl
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py
index 13f12ef1..e020178a 100644
--- a/server/text_generation_server/layers/tensor_parallel.py
+++ b/server/text_generation_server/layers/tensor_parallel.py
@@ -90,7 +90,7 @@ class TensorParallelHead(SuperLayer):
                 local_out = gather_input.T
 
             torch.mm(input, self.linear.weight.T, out=local_out)
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and gather_input.device.type == "cpu":
                 ipex.distributed.all_gather_into_tensor(
                     world_out, gather_input, group=self.process_group
                 )
@@ -107,7 +107,7 @@ class TensorParallelHead(SuperLayer):
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        if SYSTEM == "ipex":
+        if SYSTEM == "ipex" and output.device.type == "cpu":
             ipex.distributed.all_gather(world_output, output, group=self.process_group)
         else:
             torch.distributed.all_gather(world_output, output, group=self.process_group)
@@ -202,7 +202,7 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and out.device.type == "cpu":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
@@ -242,7 +242,7 @@ class TensorParallelEmbedding(torch.nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and out.device.type == "cpu":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index 4a1bef6d..2a4d72b8 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -79,14 +79,23 @@ def initialize_torch_distributed():
             ), "Each process is one xpu"
             device = RANK % torch.xpu.device_count()
             torch.xpu.set_device(device)
-
-        ipex.distributed.init_process_group(
-            backend="ccl",
-            world_size=WORLD_SIZE,
-            rank=RANK,
-            timeout=timedelta(seconds=120),
-            pg_options=options,
-        )
+            device_id = torch.device(f"xpu:{RANK}")
+            torch.distributed.init_process_group(
+                backend="xccl",
+                world_size=WORLD_SIZE,
+                rank=RANK,
+                timeout=timedelta(seconds=120),
+                pg_options=options,
+                device_id=device_id,
+            )
+        else:
+            ipex.distributed.init_process_group(
+                backend="ccl",
+                world_size=WORLD_SIZE,
+                rank=RANK,
+                timeout=timedelta(seconds=120),
+                pg_options=options,
+            )
     else:
         device = torch.device(f"cuda:{RANK}")
         torch.distributed.init_process_group(

From f147f10ed41bc13ad363fe52157f904f334ba2ee Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Tue, 27 May 2025 22:38:02 -0700
Subject: [PATCH 2/2] remove install of ipex

Signed-off-by: Wang, Yi A
---
 Dockerfile_intel | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Dockerfile_intel b/Dockerfile_intel
index e9a16d22..7714627c 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -118,7 +118,7 @@ ENV TORCH_LLM_ALLREDUCE=1
 ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
 
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.7.10%2Bxpu-cp311-cp311-linux_x86_64.whl
+#RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.7.10%2Bxpu-cp311-cp311-linux_x86_64.whl
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
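
Note (illustrative, not part of the series): the new init path can be sanity-checked outside TGI with a minimal sketch along these lines, assuming the nightly +xpu wheels installed by the Dockerfile above (whose torch build ships the native "xccl" backend). The filename and launch command are hypothetical.

# xccl_smoke_test.py -- run with: torchrun --nproc-per-node=2 xccl_smoke_test.py
import os
from datetime import timedelta

import torch
import torch.distributed as dist

# torchrun exports RANK/WORLD_SIZE (and MASTER_ADDR/MASTER_PORT) per worker.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Mirror the patched initialize_torch_distributed(): one process per XPU,
# native "xccl" backend instead of oneccl_bind_pt's "ccl".
torch.xpu.set_device(rank % torch.xpu.device_count())
dist.init_process_group(
    backend="xccl",
    world_size=world_size,
    rank=rank,
    timeout=timedelta(seconds=120),
    device_id=torch.device(f"xpu:{rank}"),
)

# The collectives tensor_parallel.py now routes through torch.distributed
# on XPU; all_reduce defaults to SUM, so each element becomes world_size.
t = torch.ones(4, device="xpu")
dist.all_reduce(t)
assert int(t[0].item()) == world_size

dist.destroy_process_group()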