ipex distributed ops support

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2024-06-18 07:12:32 -07:00
parent 397731b272
commit 1dc1d5b3c5
4 changed files with 65 additions and 31 deletions

View File

@ -61,7 +61,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
# Install server
COPY proto proto
@ -101,7 +101,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
make \
g++ \
git \
wget
wget \
cmake
ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
@ -125,21 +126,32 @@ RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh
RUN pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cpu
RUN conda install -c conda-forge gperftools mkl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
RUN pip install intel_extension_for_pytorch-2.3.100+git0eb3473-cp310-cp310-linux_x86_64.whl
RUN pip install oneccl_bind_pt==2.3.0 -f https://developer.intel.com/ipex-whl-stable-cpu
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
RUN conda install -c conda-forge gperftools
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
ENV KMP_BLOCKTIME=1
ENV KMP_TPAUSE=0
ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
# Install server
COPY proto proto

View File

@ -3,6 +3,10 @@ from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import IPEX_AVAIL
if IPEX_AVAIL:
import intel_extension_for_pytorch as ipex
class LayerConcat(torch.nn.Module):
@ -96,7 +100,11 @@ class TensorParallelHead(SuperLayer):
local_out = gather_input.T
torch.mm(input, self.linear.weight.T, out=local_out)
if IPEX_AVAIL:
ipex.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
else:
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
@ -109,6 +117,9 @@ class TensorParallelHead(SuperLayer):
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
if IPEX_AVAIL:
ipex.distributed.all_gather(world_output, output, group=self.process_group)
else:
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output
@ -206,6 +217,9 @@ class TensorParallelRowLinear(SuperLayer):
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input)
if self.process_group.size() > 1 and reduce:
if IPEX_AVAIL:
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out
@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module):
)
out = torch.nn.functional.embedding(input, self.weight)
if self.reduce and self.process_group.size() > 1:
if IPEX_AVAIL:
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out

View File

@ -3,6 +3,7 @@ import torch
from datetime import timedelta
from loguru import logger
from text_generation_server.utils.import_utils import IPEX_AVAIL
# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
@ -57,13 +58,6 @@ def initialize_torch_distributed():
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
try:
import oneccl_bindings_for_pytorch
backend = "ccl"
if os.getenv("CCL_WORKER_COUNT", None) is None:
os.environ["CCL_WORKER_COUNT"] = str(1)
except ImportError:
backend = "gloo"
options = None
@ -75,6 +69,17 @@ def initialize_torch_distributed():
if not torch.distributed.is_initialized():
# Call the init process.
if IPEX_AVAIL:
import intel_extension_for_pytorch as ipex
ipex.distributed.init_process_group(
backend="ccl",
world_size=WORLD_SIZE,
rank=RANK,
timeout=timedelta(seconds=60),
pg_options=options,
)
else:
torch.distributed.init_process_group(
backend=backend,
world_size=WORLD_SIZE,

View File

@ -1,6 +1,5 @@
import torch
from loguru import logger
from text_generation_server.utils.dist import WORLD_SIZE
def is_ipex_available():
@ -26,6 +25,7 @@ def get_xpu_free_memory(device, memory_fraction):
def get_cpu_free_memory(device, memory_fraction):
import psutil
from text_generation_server.utils.dist import WORLD_SIZE
mem = psutil.virtual_memory()
free_memory = int(mem.available * 0.95 / WORLD_SIZE)