mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 16:32:12 +00:00
improve dockerfile
This commit is contained in:
parent
7cb49f6f4f
commit
3b28cf9067
171
Dockerfile_amd
171
Dockerfile_amd
@ -88,7 +88,8 @@ RUN case ${TARGETPLATFORM} in \
|
||||
esac && \
|
||||
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
|
||||
RUN chmod +x ~/mambaforge.sh && \
|
||||
bash ~/mambaforge.sh -b -p /opt/conda && \
|
||||
rm -rf /opt/conda && \
|
||||
bash ~/mambaforge.sh -u -b -p /opt/conda && \
|
||||
mamba init && \
|
||||
rm ~/mambaforge.sh
|
||||
|
||||
@ -103,59 +104,139 @@ RUN case ${TARGETPLATFORM} in \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
# Install flash-attention, torch dependencies
|
||||
RUN pip install numpy einops ninja joblib msgpack cmake --no-cache-dir
|
||||
RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN conda install mkl=2021
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
|
||||
|
||||
|
||||
ARG COMMON_WORKDIR=/
|
||||
WORKDIR ${COMMON_WORKDIR}
|
||||
|
||||
|
||||
# Install HIPBLASLt
|
||||
ARG HIPBLASLT_BRANCH="6f65c6e"
|
||||
RUN git clone https://github.com/ROCm/hipBLASLt \
|
||||
FROM base AS build_hipblaslt
|
||||
ARG HIPBLASLT_BRANCH="e6da924"
|
||||
RUN git clone https://github.com/ROCm/hipBLASLt.git \
|
||||
&& cd hipBLASLt \
|
||||
&& git checkout ${HIPBLASLT_BRANCH} \
|
||||
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
|
||||
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
|
||||
&& cd build/release \
|
||||
&& make package
|
||||
RUN dpkg -i hipBLASLt/build/release/*.deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status \
|
||||
&& rm -rf hipBLASLt
|
||||
# RUN dpkg -i hipBLASLt/build/release/*.deb \
|
||||
# && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
# && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status \
|
||||
# && rm -rf hipBLASLt
|
||||
|
||||
RUN pip uninstall -y triton && \
|
||||
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
|
||||
cd triton/python && \
|
||||
pip install . && \
|
||||
rm -r triton
|
||||
FROM scratch AS export_hipblaslt_1
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
|
||||
FROM export_hipblaslt_1 AS export_hipblaslt
|
||||
|
||||
ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27"
|
||||
RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \
|
||||
cd pytorch && git fetch --depth 1 origin ${PYTORCH_COMMIT} && \
|
||||
git checkout ${PYTORCH_COMMIT} && \
|
||||
git submodule update --init --recursive && \
|
||||
pip install -r requirements.txt --no-cache-dir
|
||||
# RCCL build stages
|
||||
FROM base AS build_rccl
|
||||
ARG RCCL_BRANCH="rocm-6.2.0"
|
||||
RUN git clone https://github.com/ROCm/rccl \
|
||||
&& cd rccl \
|
||||
&& git checkout ${RCCL_BRANCH} \
|
||||
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
|
||||
FROM scratch AS export_rccl_1
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
|
||||
FROM scratch AS export_rccl_0
|
||||
FROM export_rccl_1 AS export_rccl
|
||||
|
||||
ARG _GLIBCXX_USE_CXX11_ABI="1"
|
||||
ARG CMAKE_PREFIX_PATH="/opt/conda"
|
||||
ARG BUILD_CAFFE2="0" \
|
||||
BUILD_CAFFE2_OPS="0" \
|
||||
USE_CUDA="0" \
|
||||
USE_ROCM="1" \
|
||||
BUILD_TEST="0" \
|
||||
USE_FBGEMM="0" \
|
||||
USE_NNPACK="0" \
|
||||
USE_QNNPACK="0" \
|
||||
USE_XNNPACK="0" \
|
||||
USE_FLASH_ATTENTION="1" \
|
||||
USE_MEM_EFF_ATTENTION="0"
|
||||
# Triton build stages
|
||||
FROM base AS build_triton
|
||||
ARG TRITON_BRANCH="e192dba"
|
||||
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
|
||||
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
|
||||
&& cd triton \
|
||||
&& git checkout ${TRITON_BRANCH} \
|
||||
&& cd python \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
FROM scratch AS export_triton_1
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
|
||||
FROM export_triton_1 AS export_triton
|
||||
|
||||
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
|
||||
RUN rm -rf pytorch
|
||||
# # AMD-SMI build stages
|
||||
FROM base AS build_amdsmi
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
&& pip wheel . --wheel-dir=dist
|
||||
FROM scratch AS export_amdsmi
|
||||
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
|
||||
|
||||
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
|
||||
ENV HIP_FORCE_DEV_KERNARG=1
|
||||
|
||||
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
|
||||
# However, Triton requires a tunning for each prompt length, which is prohibitive.
|
||||
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
|
||||
FROM base as build_pytorch
|
||||
|
||||
FROM base AS kernel-builder
|
||||
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
|
||||
if ls /install/*.deb; then \
|
||||
dpkg -i /install/*.deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
|
||||
fi
|
||||
|
||||
ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
|
||||
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
|
||||
|
||||
# A commit to fix the output scaling factor issue in _scaled_mm
|
||||
# Not yet in 2.5.0-rc1
|
||||
ARG PYTORCH_BRANCH="cedc116"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.19.1"
|
||||
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
|
||||
|
||||
RUN git clone ${PYTORCH_REPO} pytorch \
|
||||
&& cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
|
||||
&& pip install -r requirements.txt --no-cache-dir \
|
||||
&& python tools/amd_build/build_amd.py \
|
||||
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
|
||||
FROM scratch as export_pytorch_1
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
|
||||
FROM export_pytorch_1 as export_pytorch
|
||||
|
||||
FROM base AS install_deps
|
||||
|
||||
ARG COMMON_WORKDIR
|
||||
|
||||
# Install hipblaslt
|
||||
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
|
||||
if ls /install/*.deb; then \
|
||||
dpkg -i /install/*.deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
|
||||
fi
|
||||
|
||||
RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
|
||||
if ls /install/*.deb; then \
|
||||
dpkg -i /install/*.deb \
|
||||
# RCCL needs to be installed twice
|
||||
&& dpkg -i /install/*.deb \
|
||||
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
|
||||
fi
|
||||
|
||||
RUN --mount=type=bind,from=export_triton,src=/,target=/install \
|
||||
if ls /install/*.whl; then \
|
||||
# Preemptively uninstall to prevent pip same-version no-installs
|
||||
pip uninstall -y triton \
|
||||
&& pip install /install/*.whl; \
|
||||
fi
|
||||
|
||||
RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
|
||||
# Preemptively uninstall to prevent pip same-version no-installs
|
||||
pip uninstall -y amdsmi \
|
||||
&& pip install /install/*.whl;
|
||||
|
||||
RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
|
||||
if ls /install/*.whl; then \
|
||||
# Preemptively uninstall to prevent pip same-version no-installs
|
||||
pip uninstall -y torch torchvision \
|
||||
&& pip install /install/*.whl; \
|
||||
fi
|
||||
|
||||
FROM install_deps AS kernel-builder
|
||||
|
||||
# # Build vllm kernels
|
||||
FROM kernel-builder AS vllm-builder
|
||||
@ -195,7 +276,7 @@ COPY server/exllamav2_kernels/ .
|
||||
|
||||
RUN python setup.py build
|
||||
|
||||
FROM base AS base-copy
|
||||
FROM install_deps AS base-copy
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HF_HOME=/data \
|
||||
@ -245,6 +326,12 @@ ENTRYPOINT ["./entrypoint.sh"]
|
||||
# Final image
|
||||
FROM base-copy
|
||||
|
||||
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
|
||||
ENV HIP_FORCE_DEV_KERNARG=1
|
||||
|
||||
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
|
||||
# However, Triton requires a tunning for each prompt length, which is prohibitive.
|
||||
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
|
||||
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
|
||||
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
|
||||
ENV VLLM_MOE_PADDING=0
|
||||
|
@ -1,5 +1,5 @@
|
||||
flash_att_v2_commit_cuda := v2.6.1
|
||||
flash_att_v2_commit_rocm := 3cea2fb6ee54fb7e1aad9db6ac6c9331184b8647 # (Aug28)
|
||||
flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
|
||||
|
||||
build-flash-attention-v2-cuda:
|
||||
pip install -U packaging wheel
|
||||
@ -11,7 +11,7 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
|
||||
build-flash-attention-v2-rocm:
|
||||
if [ ! -d 'flash-attention-v2' ]; then \
|
||||
pip install -U packaging ninja --no-cache-dir && \
|
||||
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
|
||||
git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
|
||||
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
|
||||
git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
|
||||
fi
|
||||
|
@ -1,5 +1,17 @@
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
import torch
|
||||
|
||||
extra_cuda_cflags = []
|
||||
extra_cflags = []
|
||||
if torch.version.hip:
|
||||
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
|
||||
extra_compile_args = {
|
||||
"cxx": extra_cflags,
|
||||
"nvcc": extra_cuda_cflags,
|
||||
}
|
||||
|
||||
setup(
|
||||
name="exllama_kernels",
|
||||
@ -13,6 +25,7 @@ setup(
|
||||
"exllama_kernels/cuda_func/q4_matmul.cu",
|
||||
"exllama_kernels/cuda_func/q4_matrix.cu",
|
||||
],
|
||||
extra_compile_args=extra_compile_args,
|
||||
)
|
||||
],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
|
@ -3,11 +3,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
import torch
|
||||
|
||||
extra_cuda_cflags = ["-lineinfo", "-O3"]
|
||||
|
||||
extra_cflags = []
|
||||
if torch.version.hip:
|
||||
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
|
||||
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
|
||||
extra_compile_args = {
|
||||
"cxx": extra_cflags,
|
||||
"nvcc": extra_cuda_cflags,
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user