diff --git a/Dockerfile_amd b/Dockerfile_amd
index 766881a8..388cca83 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -88,7 +88,8 @@ RUN case ${TARGETPLATFORM} in \
     esac && \
     curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
 RUN chmod +x ~/mambaforge.sh && \
-    bash ~/mambaforge.sh -b -p /opt/conda && \
+    rm -rf /opt/conda && \
+    bash ~/mambaforge.sh -u -b -p /opt/conda && \
     mamba init && \
     rm ~/mambaforge.sh

@@ -103,59 +104,139 @@ RUN case ${TARGETPLATFORM} in \
     /opt/conda/bin/conda clean -ya

 # Install flash-attention, torch dependencies
-RUN pip install numpy einops ninja joblib msgpack cmake --no-cache-dir
+RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
+
+RUN conda install mkl=2021
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
+
+
+ARG COMMON_WORKDIR=/
+WORKDIR ${COMMON_WORKDIR}
+
 # Install HIPBLASLt
-ARG HIPBLASLT_BRANCH="6f65c6e"
-RUN git clone https://github.com/ROCm/hipBLASLt \
+FROM base AS build_hipblaslt
+ARG HIPBLASLT_BRANCH="e6da924"
+RUN git clone https://github.com/ROCm/hipBLASLt.git \
     && cd hipBLASLt \
     && git checkout ${HIPBLASLT_BRANCH} \
-    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
+    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
     && cd build/release \
     && make package
-RUN dpkg -i hipBLASLt/build/release/*.deb \
-    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
-    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status \
-    && rm -rf hipBLASLt
+# RUN dpkg -i hipBLASLt/build/release/*.deb \
+#    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+#    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status \
+#    && rm -rf hipBLASLt

-RUN pip uninstall -y triton && \
-    git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
-    cd triton/python && \
-    pip install . && \
-    rm -r triton
+FROM scratch AS export_hipblaslt_1
+ARG COMMON_WORKDIR
+COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
+FROM export_hipblaslt_1 AS export_hipblaslt

-ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27"
-RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \
-    cd pytorch && git fetch --depth 1 origin ${PYTORCH_COMMIT} && \
-    git checkout ${PYTORCH_COMMIT} && \
-    git submodule update --init --recursive && \
-    pip install -r requirements.txt --no-cache-dir
+# RCCL build stages
+FROM base AS build_rccl
+ARG RCCL_BRANCH="rocm-6.2.0"
+RUN git clone https://github.com/ROCm/rccl \
+    && cd rccl \
+    && git checkout ${RCCL_BRANCH} \
+    && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
+FROM scratch AS export_rccl_1
+ARG COMMON_WORKDIR
+COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
+FROM scratch AS export_rccl_0
+FROM export_rccl_1 AS export_rccl

-ARG _GLIBCXX_USE_CXX11_ABI="1"
-ARG CMAKE_PREFIX_PATH="/opt/conda"
-ARG BUILD_CAFFE2="0" \
-    BUILD_CAFFE2_OPS="0" \
-    USE_CUDA="0" \
-    USE_ROCM="1" \
-    BUILD_TEST="0" \
-    USE_FBGEMM="0" \
-    USE_NNPACK="0" \
-    USE_QNNPACK="0" \
-    USE_XNNPACK="0" \
-    USE_FLASH_ATTENTION="1" \
-    USE_MEM_EFF_ATTENTION="0"
+# Triton build stages
+FROM base AS build_triton
+ARG TRITON_BRANCH="e192dba"
+ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
+RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
+    && cd triton \
+    && git checkout ${TRITON_BRANCH} \
+    && cd python \
+    && python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch AS export_triton_1
+ARG COMMON_WORKDIR
+COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
+FROM export_triton_1 AS export_triton

-RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
-RUN rm -rf pytorch
+# # AMD-SMI build stages
+FROM base AS build_amdsmi
+RUN cd /opt/rocm/share/amd_smi \
+    && pip wheel . --wheel-dir=dist
+FROM scratch AS export_amdsmi
+COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /

-# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
-ENV HIP_FORCE_DEV_KERNARG=1
-# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
-# However, Triton requires a tunning for each prompt length, which is prohibitive.
-ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0

+FROM base AS build_pytorch

-FROM base AS kernel-builder
+RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
+    if ls /install/*.deb; then \
+        dpkg -i /install/*.deb \
+        && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+        && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
+    fi
+
+ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
+ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+
+# A commit to fix the output scaling factor issue in _scaled_mm
+# Not yet in 2.5.0-rc1
+ARG PYTORCH_BRANCH="cedc116"
+ARG PYTORCH_VISION_BRANCH="v0.19.1"
+ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
+
+RUN git clone ${PYTORCH_REPO} pytorch \
+    && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
+    && pip install -r requirements.txt --no-cache-dir \
+    && python tools/amd_build/build_amd.py \
+    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch AS export_pytorch_1
+ARG COMMON_WORKDIR
+COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
+FROM export_pytorch_1 AS export_pytorch
+
+FROM base AS install_deps
+
+ARG COMMON_WORKDIR
+
+# Install hipblaslt
+RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
+    if ls /install/*.deb; then \
+        dpkg -i /install/*.deb \
+        && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+        && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
+    fi
+
+RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
+    if ls /install/*.deb; then \
+        dpkg -i /install/*.deb \
+        # RCCL needs to be installed twice
+        && dpkg -i /install/*.deb \
+        && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
+        && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
+    fi
+
+RUN --mount=type=bind,from=export_triton,src=/,target=/install \
+    if ls /install/*.whl; then \
+        # Uninstall first, otherwise pip skips packages already installed at the same version
+        pip uninstall -y triton \
+        && pip install /install/*.whl; \
+    fi
+
+RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
+    # Uninstall first, otherwise pip skips packages already installed at the same version
+    pip uninstall -y amdsmi \
+    && pip install /install/*.whl;
+
+RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
+    if ls /install/*.whl; then \
+        # Uninstall first, otherwise pip skips packages already installed at the same version
+        pip uninstall -y torch torchvision \
+        && pip install /install/*.whl; \
+    fi
+
+FROM install_deps AS kernel-builder

 # # Build vllm kernels
 FROM kernel-builder AS vllm-builder
@@ -195,7 +276,7 @@ COPY server/exllamav2_kernels/ .

 RUN python setup.py build

-FROM base AS base-copy
+FROM install_deps AS base-copy

 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -245,6 +326,12 @@ ENTRYPOINT ["./entrypoint.sh"]

 # Final image
 FROM base-copy
+# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
+ENV HIP_FORCE_DEV_KERNARG=1
+
+# On MI250 and MI300, performance of flash attention with Triton FA is slightly better than with CK.
+# However, Triton requires tuning for each prompt length, which is prohibitive.
+ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
 ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
 ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
 ENV VLLM_MOE_PADDING=0
diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index 74293d9c..a9cdf782 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 3cea2fb6ee54fb7e1aad9db6ac6c9331184b8647 # (Aug28)
+flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4

 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel
@@ -11,7 +11,7 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
 build-flash-attention-v2-rocm:
 	if [ ! -d 'flash-attention-v2' ]; then \
 		pip install -U packaging ninja --no-cache-dir && \
-		git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
+		git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
 		cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
 		git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
 	fi
diff --git a/server/exllama_kernels/setup.py b/server/exllama_kernels/setup.py
index 987d181e..cc307bf0 100644
--- a/server/exllama_kernels/setup.py
+++ b/server/exllama_kernels/setup.py
@@ -1,5 +1,17 @@
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+
+extra_cuda_cflags = []
+extra_cflags = []
+if torch.version.hip:
+    extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
+    extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
+
+extra_compile_args = {
+    "cxx": extra_cflags,
+    "nvcc": extra_cuda_cflags,
+}

 setup(
     name="exllama_kernels",
@@ -13,6 +25,7 @@ setup(
             "exllama_kernels/cuda_func/q4_matmul.cu",
             "exllama_kernels/cuda_func/q4_matrix.cu",
         ],
+        extra_compile_args=extra_compile_args,
     )
 ],
 cmdclass={"build_ext": BuildExtension},
diff --git a/server/exllamav2_kernels/setup.py b/server/exllamav2_kernels/setup.py
index 4a16b546..56ffa973 100644
--- a/server/exllamav2_kernels/setup.py
+++ b/server/exllamav2_kernels/setup.py
@@ -3,11 +3,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 import torch

 extra_cuda_cflags = ["-lineinfo", "-O3"]
-
+extra_cflags = []
 if torch.version.hip:
-    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
+    extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
+    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"]

 extra_compile_args = {
+    "cxx": extra_cflags,
     "nvcc": extra_cuda_cflags,
 }
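
Note: both setup.py files now branch on `torch.version.hip`, which is `None` on CUDA builds and a version string on ROCm builds, so `-DLEGACY_HIPBLAS_DIRECT=ON` is only passed to HIP compiles. A quick, purely illustrative check of which branch a given environment will take:

    python -c "import torch; print('HIP' if torch.version.hip else 'CUDA', torch.version.hip or torch.version.cuda)"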
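
Note on the new `FROM scratch AS export_*` stages in Dockerfile_amd: besides feeding the `RUN --mount=type=bind,from=...` install steps, each one can be built on its own and dumped to disk with BuildKit's local exporter, which is convenient when iterating on a single dependency. A minimal sketch, assuming BuildKit is enabled (the `./wheels` destination directory is illustrative):

    DOCKER_BUILDKIT=1 docker build -f Dockerfile_amd \
        --target export_pytorch \
        --output type=local,dest=./wheels .

The same invocation with `--target export_triton` or `--target export_rccl` drops the Triton wheel or the RCCL .deb packages into the destination directory instead.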