diff --git a/Dockerfile_intel b/Dockerfile_intel
index 3bc04332..7714627c 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -45,7 +45,7 @@ RUN cargo build --profile release-opt --frozen
 
 # Text Generation Inference base image for Intel
 
-FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu
+FROM intel/oneapi-basekit:2025.1.3-0-devel-ubuntu22.04 AS xpu
 
 USER root
 
@@ -99,7 +99,7 @@ ENV HF_HOME=/data \
 
 WORKDIR /usr/src
 
-RUN pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/xpu
+RUN pip install --pre torch==2.8.0.dev20250526+xpu torchvision==0.22.0.dev20250526+xpu --index-url https://download.pytorch.org/whl/nightly/xpu
 
 # Install server
 COPY proto proto
@@ -117,8 +117,6 @@ ENV TORCH_LLM_ALLREDUCE=1
 ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
 
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.7.0%2Bxpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.7.10%2Bxpu-cp311-cp311-linux_x86_64.whl
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
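
Since the image now pulls a pre-release wheel from the PyTorch nightly XPU index, a quick smoke test after building is worthwhile. A minimal sketch (the script name and expected version string are illustrative, not part of this patch):

# check_xpu.py - sanity-check the nightly XPU torch install
import torch

print(torch.__version__)         # expect something like 2.8.0.dev20250526+xpu
print(torch.xpu.is_available())  # True when an Intel GPU and drivers are visible
print(torch.xpu.device_count())

# a tiny matmul confirms XPU kernels actually dispatch
x = torch.randn(4, 4, device="xpu")
print((x @ x.T).sum().item())
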
diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py
index 13f12ef1..e020178a 100644
--- a/server/text_generation_server/layers/tensor_parallel.py
+++ b/server/text_generation_server/layers/tensor_parallel.py
@@ -90,7 +90,7 @@ class TensorParallelHead(SuperLayer):
                 local_out = gather_input.T
 
             torch.mm(input, self.linear.weight.T, out=local_out)
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and gather_input.device.type == "cpu":
                 ipex.distributed.all_gather_into_tensor(
                     world_out, gather_input, group=self.process_group
                 )
@@ -107,7 +107,7 @@ class TensorParallelHead(SuperLayer):
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        if SYSTEM == "ipex":
+        if SYSTEM == "ipex" and output.device.type == "cpu":
             ipex.distributed.all_gather(world_output, output, group=self.process_group)
         else:
             torch.distributed.all_gather(world_output, output, group=self.process_group)
@@ -202,7 +202,7 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and out.device.type == "cpu":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
@@ -242,7 +242,7 @@ class TensorParallelEmbedding(torch.nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and out.device.type == "cpu":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
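
All four call sites above apply the same rule: after this change, ipex.distributed handles collectives only for CPU tensors, while XPU tensors go through the stock torch.distributed collectives backed by the xccl process group set up in dist.py. A condensed sketch of the shared pattern (the helper is illustrative; the patch inlines the check at each site):

import torch
import torch.distributed

def dispatch_all_reduce(out: torch.Tensor, process_group, system: str) -> None:
    # CPU path keeps the oneCCL bindings shipped with intel_extension_for_pytorch
    if system == "ipex" and out.device.type == "cpu":
        import intel_extension_for_pytorch as ipex
        ipex.distributed.all_reduce(out, group=process_group)
    else:
        # XPU (and CUDA) tensors use the native torch.distributed backend
        torch.distributed.all_reduce(out, group=process_group)
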
diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index 4a1bef6d..2a4d72b8 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -79,14 +79,23 @@ def initialize_torch_distributed():
                     ), "Each process is one xpu"
                     device = RANK % torch.xpu.device_count()
                     torch.xpu.set_device(device)
-
-                ipex.distributed.init_process_group(
-                    backend="ccl",
-                    world_size=WORLD_SIZE,
-                    rank=RANK,
-                    timeout=timedelta(seconds=120),
-                    pg_options=options,
-                )
+                    device_id = torch.device(f"xpu:{device}")
+                    torch.distributed.init_process_group(
+                        backend="xccl",
+                        world_size=WORLD_SIZE,
+                        rank=RANK,
+                        timeout=timedelta(seconds=120),
+                        pg_options=options,
+                        device_id=device_id,
+                    )
+                else:
+                    ipex.distributed.init_process_group(
+                        backend="ccl",
+                        world_size=WORLD_SIZE,
+                        rank=RANK,
+                        timeout=timedelta(seconds=120),
+                        pg_options=options,
+                    )
             else:
                 device = torch.device(f"cuda:{RANK}")
                 torch.distributed.init_process_group(
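
The net effect of this hunk: on XPU, process-group setup bypasses ipex.distributed entirely and uses the native xccl backend, while the CPU path keeps the oneCCL-based ipex.distributed initialization. A standalone sketch of the new XPU path, assuming the nightly wheel installed above (the environment-variable handling and defaults are illustrative):

import os
from datetime import timedelta

import torch
import torch.distributed as dist

# single-node defaults so the sketch runs standalone; launchers normally set these
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

device = rank % torch.xpu.device_count()
torch.xpu.set_device(device)

dist.init_process_group(
    backend="xccl",  # native XPU collectives in recent torch builds
    world_size=world_size,
    rank=rank,
    timeout=timedelta(seconds=120),
    device_id=torch.device(f"xpu:{device}"),
)

print(dist.get_backend())  # expect "xccl"
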