diff --git a/Dockerfile_intel b/Dockerfile_intel index 30c3eb6d..0bcb3585 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -61,7 +61,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \ WORKDIR /usr/src RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl -RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed # Install server COPY proto proto @@ -101,7 +101,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins make \ g++ \ git \ - wget + wget \ + cmake ENV HUGGINGFACE_HUB_CACHE=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ @@ -125,21 +126,32 @@ RUN chmod +x ~/mambaforge.sh && \ bash ~/mambaforge.sh -b -p /opt/conda && \ rm ~/mambaforge.sh -RUN pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cpu +RUN conda install -c conda-forge gperftools mkl + +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl WORKDIR /usr/src -RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl -RUN pip install intel_extension_for_pytorch-2.3.100+git0eb3473-cp310-cp310-linux_x86_64.whl -RUN pip install oneccl_bind_pt==2.3.0 -f https://developer.intel.com/ipex-whl-stable-cpu +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a -RUN conda install -c conda-forge gperftools +RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131 -ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so +RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install + +RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install . + +ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib +ENV KMP_BLOCKTIME=1 +ENV KMP_TPAUSE=0 +ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist +ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist +ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist # Install server COPY proto proto diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 6005f737..510dc2c6 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -3,6 +3,10 @@ from torch.nn import functional as F from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.utils.import_utils import IPEX_AVAIL + +if IPEX_AVAIL: + import intel_extension_for_pytorch as ipex class LayerConcat(torch.nn.Module): @@ -96,10 +100,14 @@ class TensorParallelHead(SuperLayer): local_out = gather_input.T torch.mm(input, self.linear.weight.T, out=local_out) - - torch.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group - ) + if IPEX_AVAIL: + ipex.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) + else: + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) if input.shape[0] == 1: return world_out @@ -109,7 +117,10 @@ class TensorParallelHead(SuperLayer): world_output = [ torch.empty_like(output) for _ in range(self.process_group.size()) ] - torch.distributed.all_gather(world_output, output, group=self.process_group) + if IPEX_AVAIL: + ipex.distributed.all_gather(world_output, output, group=self.process_group) + else: + torch.distributed.all_gather(world_output, output, group=self.process_group) world_output = torch.cat(world_output, dim=-1) return world_output @@ -206,7 +217,10 @@ class TensorParallelRowLinear(SuperLayer): def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: out = super().forward(input) if self.process_group.size() > 1 and reduce: - torch.distributed.all_reduce(out, group=self.process_group) + if IPEX_AVAIL: + ipex.distributed.all_reduce(out, group=self.process_group) + else: + torch.distributed.all_reduce(out, group=self.process_group) return out @@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module): ) out = torch.nn.functional.embedding(input, self.weight) if self.reduce and self.process_group.size() > 1: - torch.distributed.all_reduce(out, group=self.process_group) + if IPEX_AVAIL: + ipex.distributed.all_reduce(out, group=self.process_group) + else: + torch.distributed.all_reduce(out, group=self.process_group) return out diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 3625e6f2..7d387563 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -3,6 +3,7 @@ import torch from datetime import timedelta from loguru import logger +from text_generation_server.utils.import_utils import IPEX_AVAIL # Tensor Parallelism settings RANK = int(os.getenv("RANK", "0")) @@ -57,14 +58,7 @@ def initialize_torch_distributed(): options.is_high_priority_stream = True options._timeout = timedelta(seconds=60) else: - try: - import oneccl_bindings_for_pytorch - - backend = "ccl" - if os.getenv("CCL_WORKER_COUNT", None) is None: - os.environ["CCL_WORKER_COUNT"] = str(1) - except ImportError: - backend = "gloo" + backend = "gloo" options = None if WORLD_SIZE == 1: @@ -75,13 +69,24 @@ def initialize_torch_distributed(): if not torch.distributed.is_initialized(): # Call the init process. - torch.distributed.init_process_group( - backend=backend, - world_size=WORLD_SIZE, - rank=RANK, - timeout=timedelta(seconds=60), - pg_options=options, - ) + if IPEX_AVAIL: + import intel_extension_for_pytorch as ipex + + ipex.distributed.init_process_group( + backend="ccl", + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) + else: + torch.distributed.init_process_group( + backend=backend, + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) else: logger.warning("torch.distributed is already initialized.") diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index f1fd8f79..bb6821c8 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,6 +1,5 @@ import torch from loguru import logger -from text_generation_server.utils.dist import WORLD_SIZE def is_ipex_available(): @@ -26,6 +25,7 @@ def get_xpu_free_memory(device, memory_fraction): def get_cpu_free_memory(device, memory_fraction): import psutil + from text_generation_server.utils.dist import WORLD_SIZE mem = psutil.virtual_memory() free_memory = int(mem.available * 0.95 / WORLD_SIZE)