ipex distributed ops support

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

parent 397731b272
commit 1dc1d5b3c5
@@ -61,7 +61,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
 
 WORKDIR /usr/src
 RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
 
 # Install server
 COPY proto proto
@@ -101,7 +101,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     make \
     g++ \
     git \
-    wget
+    wget \
+    cmake
 
 ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
@@ -125,21 +126,32 @@ RUN chmod +x ~/mambaforge.sh && \
     bash ~/mambaforge.sh -b -p /opt/conda && \
     rm ~/mambaforge.sh
 
-RUN pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cpu
+RUN conda install -c conda-forge gperftools mkl
 
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
 
 WORKDIR /usr/src
 
-RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
-RUN pip install intel_extension_for_pytorch-2.3.100+git0eb3473-cp310-cp310-linux_x86_64.whl
-RUN pip install oneccl_bind_pt==2.3.0 -f https://developer.intel.com/ipex-whl-stable-cpu
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
 
-RUN conda install -c conda-forge gperftools
+RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
 
-ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
+RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
 
+RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
 
+ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
 ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
 ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
 ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
 ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
+ENV KMP_BLOCKTIME=1
+ENV KMP_TPAUSE=0
+ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
+ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
+ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
 
 # Install server
 COPY proto proto
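
Note: the ENV block above supports the ccl backend at runtime: CCL_ROOT, I_MPI_ROOT, and FI_PROVIDER_PATH point the oneCCL bindings at their bundled MPI and libfabric, LD_PRELOAD pins tcmalloc plus Intel OpenMP (libiomp5), and the KMP_* variables tune OpenMP blocking and barrier behavior. A hypothetical post-build sanity check (not part of this commit) that the stack is importable and exposes the collectives the server calls:

import torch
import intel_extension_for_pytorch as ipex
import oneccl_bindings_for_pytorch  # noqa: F401  registers the "ccl" backend

print(torch.__version__, ipex.__version__)
for op in ("init_process_group", "all_reduce", "all_gather", "all_gather_into_tensor"):
    assert hasattr(ipex.distributed, op), f"missing ipex.distributed.{op}"
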
@@ -3,6 +3,10 @@ from torch.nn import functional as F
 from typing import Iterable, List
 from text_generation_server.layers.linear import get_linear, FastLinear
 from text_generation_server.layers.exl2 import Exl2Weight
+from text_generation_server.utils.import_utils import IPEX_AVAIL
 
+if IPEX_AVAIL:
+    import intel_extension_for_pytorch as ipex
 
 
 class LayerConcat(torch.nn.Module):
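
Note: the IPEX_AVAIL flag imported here comes from text_generation_server/utils/import_utils.py, which already defines is_ipex_available() (see the final hunks of this commit). A minimal sketch of how such a flag can be derived, assuming a plain import probe:

import importlib.util

def is_ipex_available():
    # True when intel_extension_for_pytorch is importable in this environment.
    return importlib.util.find_spec("intel_extension_for_pytorch") is not None

IPEX_AVAIL = is_ipex_available()
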
@@ -96,10 +100,14 @@ class TensorParallelHead(SuperLayer):
                 local_out = gather_input.T
 
             torch.mm(input, self.linear.weight.T, out=local_out)
-            torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
+            if IPEX_AVAIL:
+                ipex.distributed.all_gather_into_tensor(
+                    world_out, gather_input, group=self.process_group
+                )
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    world_out, gather_input, group=self.process_group
+                )
 
             if input.shape[0] == 1:
                 return world_out
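
Note: both branches above share the same shape contract: all_gather_into_tensor fills an output whose leading dimension is world_size times the input's, with each rank's shard concatenated along dim 0. A small illustration of that contract with the stock torch.distributed API (helper name and shapes are illustrative):

import torch
import torch.distributed as dist

def gather_shards(gather_input, group=None):
    # world_out's leading dim is the per-rank shard dim times world_size.
    world_size = dist.get_world_size(group)
    world_out = torch.empty(
        (gather_input.shape[0] * world_size, *gather_input.shape[1:]),
        dtype=gather_input.dtype,
        device=gather_input.device,
    )
    dist.all_gather_into_tensor(world_out, gather_input, group=group)
    return world_out
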
@@ -109,7 +117,10 @@ class TensorParallelHead(SuperLayer):
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        if IPEX_AVAIL:
+            ipex.distributed.all_gather(world_output, output, group=self.process_group)
+        else:
+            torch.distributed.all_gather(world_output, output, group=self.process_group)
         world_output = torch.cat(world_output, dim=-1)
         return world_output
 
@@ -206,7 +217,10 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if IPEX_AVAIL:
+                ipex.distributed.all_reduce(out, group=self.process_group)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
 
@@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if IPEX_AVAIL:
+                ipex.distributed.all_reduce(out, group=self.process_group)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out
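
Note: the same IPEX_AVAIL if/else now guards every collective call site in this file. One possible follow-up (a sketch only, not what this commit does) would be to centralize the dispatch in a single helper:

import torch
import torch.distributed

from text_generation_server.utils.import_utils import IPEX_AVAIL

if IPEX_AVAIL:
    import intel_extension_for_pytorch as ipex

def all_reduce(tensor, group=None):
    # Route to the IPEX-optimized collective when present, otherwise fall
    # back to the stock torch.distributed implementation.
    if IPEX_AVAIL:
        ipex.distributed.all_reduce(tensor, group=group)
    else:
        torch.distributed.all_reduce(tensor, group=group)
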
@@ -3,6 +3,7 @@ import torch
 
 from datetime import timedelta
 from loguru import logger
+from text_generation_server.utils.import_utils import IPEX_AVAIL
 
 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
@@ -57,14 +58,7 @@ def initialize_torch_distributed():
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
     else:
-        try:
-            import oneccl_bindings_for_pytorch
-
-            backend = "ccl"
-            if os.getenv("CCL_WORKER_COUNT", None) is None:
-                os.environ["CCL_WORKER_COUNT"] = str(1)
-        except ImportError:
-            backend = "gloo"
+        backend = "gloo"
         options = None
 
     if WORLD_SIZE == 1:
@@ -75,13 +69,24 @@ def initialize_torch_distributed():
 
     if not torch.distributed.is_initialized():
         # Call the init process.
-        torch.distributed.init_process_group(
-            backend=backend,
-            world_size=WORLD_SIZE,
-            rank=RANK,
-            timeout=timedelta(seconds=60),
-            pg_options=options,
-        )
+        if IPEX_AVAIL:
+            import intel_extension_for_pytorch as ipex
+
+            ipex.distributed.init_process_group(
+                backend="ccl",
+                world_size=WORLD_SIZE,
+                rank=RANK,
+                timeout=timedelta(seconds=60),
+                pg_options=options,
+            )
+        else:
+            torch.distributed.init_process_group(
+                backend=backend,
+                world_size=WORLD_SIZE,
+                rank=RANK,
+                timeout=timedelta(seconds=60),
+                pg_options=options,
+            )
     else:
         logger.warning("torch.distributed is already initialized.")
 
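
Note: initialize_torch_distributed() reads RANK and WORLD_SIZE from the environment, so the sharded server runs one process per rank. A hypothetical single-node, two-rank smoke test of the new ccl path (file name, address, and port are illustrative):

# smoke_ccl.py -- launch once with RANK=0 and once with RANK=1, WORLD_SIZE=2.
import os
from datetime import timedelta

import torch
import intel_extension_for_pytorch as ipex
import oneccl_bindings_for_pytorch  # noqa: F401  registers the "ccl" backend

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

ipex.distributed.init_process_group(
    backend="ccl",
    world_size=int(os.environ["WORLD_SIZE"]),
    rank=int(os.environ["RANK"]),
    timeout=timedelta(seconds=60),
)

t = torch.ones(4)
ipex.distributed.all_reduce(t)
print(t)  # every rank should print a tensor filled with WORLD_SIZE
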
@@ -1,6 +1,5 @@
 import torch
 from loguru import logger
-from text_generation_server.utils.dist import WORLD_SIZE
 
 
 def is_ipex_available():
@@ -26,6 +25,7 @@ def get_xpu_free_memory(device, memory_fraction):
 
 def get_cpu_free_memory(device, memory_fraction):
     import psutil
+    from text_generation_server.utils.dist import WORLD_SIZE
 
     mem = psutil.virtual_memory()
     free_memory = int(mem.available * 0.95 / WORLD_SIZE)
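
Note: the WORLD_SIZE import moves from module level into get_cpu_free_memory because dist.py now imports IPEX_AVAIL from import_utils at load time; importing dist back at the top of import_utils would create a circular import. Deferring the import to call time breaks the cycle, e.g. (module names are illustrative):

# b.py is imported by a.py at module level, so b.py must not import a.py at
# its own top level. A function-local import runs only at call time, after
# both modules have finished loading, so the cycle never materializes.

# a.py
from b import FLAG

THING = 42

# b.py
FLAG = True

def needs_thing():
    from a import THING  # deferred import: safe despite a.py importing b.py
    return THING
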