Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-23 16:02:10 +00:00)
Adds support for AMD Instinct MI300 in TGI.
Most changes are:
* Support PyTorch TunableOp to pick the GEMM/GEMV kernels for decoding
https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable.
TunableOp is disabled by default, and can be enabled with
`PYTORCH_TUNABLEOP_ENABLED=1`.
* Update ROCm dockerfile to PyTorch 2.3 (actually patched with changes
from https://github.com/pytorch/pytorch/pull/124362)
* Support SILU & Linear custom kernels contributed by AMD
* Update vLLM paged attention to https://github.com/fxmarty/rocm-vllm/,
branching off a much more recent commit (3489ce7936)
* Support FA2 Triton kernel as recommended by AMD. Can be used by
specifying `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update dockerfile to ROCm 6.1
By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) so that
the tuning does not have to be rerun on every `docker run`.
Example:
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```
---------
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
217 lines
6.5 KiB
Plaintext
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef
WORKDIR /usr/src

# An ARG goes out of scope at the end of the stage that declares it, so the
# `planner` and `builder` stages (both `FROM chef`) would never see it.
# Re-export it as ENV so it is inherited by those stages, while keeping the
# ARG so the value can still be overridden with --build-arg.
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
ENV CARGO_REGISTRIES_CRATES_IO_PROTOCOL=${CARGO_REGISTRIES_CRATES_IO_PROTOCOL}

# Compute the cargo-chef dependency recipe from the workspace manifests and
# crates, so the `builder` stage can cook dependencies in a cacheable layer.
FROM chef AS planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

# Build the Rust binaries (benchmark, router, launcher).
FROM chef AS builder

# protoc is needed to compile the .proto files. -f makes curl fail on HTTP
# errors instead of saving an error page that unzip would then choke on.
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -fsSLO https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

# Cook dependencies first so this layer is reused until the recipe changes.
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher

# Declared as late as possible: every RUN after an ARG depends on its value,
# so declaring these (which change on every commit) before the protoc/cook
# steps would invalidate the dependency cache on each build.
ARG GIT_SHA
ARG DOCKER_LABEL
RUN cargo build --release

# Text Generation Inference base image for ROCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        ccache \
        curl \
        git \
        make \
        libssl-dev \
        g++ \
        # Needed to build VLLM & flash.
        rocthrust-dev \
        hipsparse-dev \
        hipblas-dev \
        hipblaslt-dev \
        rocblas-dev \
        hiprand-dev \
        rocrand-dev \
        miopen-hip-dev \
        hipfft-dev \
        hipcub-dev \
        hipsolver-dev \
        rccl-dev \
        cmake \
        python3-dev && \
    rm -rf /var/lib/apt/lists/*

# Keep in sync with `server/pyproject.toml`
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.3.0'
# Matches the ROCm release of the base image (rocm/dev-ubuntu-22.04:6.1.1_hip_update).
ARG ROCM_VERSION='6.1.1'
ARG PYTHON_VERSION='3.10.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH=/opt/conda/bin:$PATH

# Python is provided by Mambaforge under /opt/conda rather than by the distro.
# Kernel build artifacts (lib.linux-x86_64-cpython-310) are later copied into
# /opt/conda/lib/python3.10/site-packages, so this interpreter must be 3.10.
# Docker's TARGETPLATFORM is translated into a mamba arch. The installer is
# downloaded, run and removed in a single layer so it never persists in the image.
RUN case "${TARGETPLATFORM}" in \
        "linux/arm64") MAMBA_ARCH=aarch64 ;; \
        *) MAMBA_ARCH=x86_64 ;; \
    esac && \
    curl -fsSL -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \
    chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    mamba init && \
    rm ~/mambaforge.sh

# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir

# MKL static libs/headers, presumably for the PyTorch source build below.
# -y is required: a docker build cannot answer conda's confirmation prompt.
RUN conda install -y intel::mkl-static intel::mkl-include && \
    conda clean -ya

# Replace any pip-provided Triton with a source build of ROCm's fork, then
# drop the source tree so it does not stay in the image.
# NOTE(review): the clone is unpinned (default-branch HEAD), so this layer is
# not reproducible; consider pinning a commit.
RUN pip uninstall -y triton && \
    git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
    cd triton/python && \
    pip install --no-cache-dir . && \
    cd ../.. && \
    rm -rf triton

RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && \
    cd pytorch && \
    pip install -r requirements.txt --no-cache-dir

# PyTorch source-build configuration. Docker exposes these ARG values as
# environment variables to the RUN below (presumably read by PyTorch's
# setup.py to select the ROCm build and disable unneeded components).
ARG _GLIBCXX_USE_CXX11_ABI="1" \
    CMAKE_PREFIX_PATH="/opt/conda" \
    PYTORCH_ROCM_ARCH="gfx90a;gfx942"
ARG BUILD_CAFFE2="0" \
    BUILD_CAFFE2_OPS="0" \
    BUILD_TEST="0" \
    USE_CUDA="0" \
    USE_ROCM="1" \
    USE_FBGEMM="0" \
    USE_NNPACK="0" \
    USE_QNNPACK="0" \
    USE_XNNPACK="0" \
    USE_FLASH_ATTENTION="1" \
    USE_MEM_EFF_ATTENTION="0"

# Run PyTorch's AMD build-preparation script, then build and install PyTorch
# with the conda Python found first on PATH.
RUN cd pytorch \
    && python tools/amd_build/build_amd.py \
    && python setup.py install

# Runtime environment:
# - HIP_FORCE_DEV_KERNARG=1, as recommended in
#   https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
# - ROCM_USE_FLASH_ATTN_V2_TRITON=1 selects the Triton flash-attention kernel;
#   on MI300 its performance is very competitive (actually better than CK).
ENV HIP_FORCE_DEV_KERNARG=1 \
    ROCM_USE_FLASH_ATTN_V2_TRITON=1

# Shared parent for all the kernel-compilation stages that follow.
FROM base AS kernel-builder

# Build vllm kernels: a specific vllm version, via the server's Makefile-vllm.
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src
COPY server/Makefile-vllm ./Makefile
RUN make build-vllm-rocm

# Build a specific version of the Flash Attention v2 kernels for ROCm,
# driven by the server's Makefile-flash-att-v2.
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 ./Makefile
RUN make build-flash-attention-v2-rocm

# Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN python setup.py build

# Build exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN python setup.py build

# Build exllama v2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
RUN python setup.py build

# Runtime image: the ROCm/PyTorch base plus the compiled kernels, the Python
# server and the Rust binaries.
FROM base AS base-copy

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

# Pull the compiled extensions out of each kernel builder and drop them into
# the conda site-packages. The sources are CPython 3.10 build directories,
# matching the /opt/conda Python 3.10 interpreter.
# vllm
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# flash attention v2
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# custom kernels
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# exllama
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# exllama v2
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Install server. `COPY server server` already brings in server/Makefile, so
# no separate copy of the Makefile is needed. Both pip installs skip the cache
# to keep it out of the image layer.
COPY proto proto
COPY server server
RUN cd server && \
    make gen-server && \
    pip install --no-cache-dir -r requirements_rocm.txt && \
    pip install ".[accelerate, peft, outlines]" --no-cache-dir

# Install the benchmarker, router and launcher built in the Rust `builder`
# stage (multiple sources require the destination to end with '/').
COPY --from=builder \
    /usr/src/target/release/text-generation-benchmark \
    /usr/src/target/release/text-generation-router \
    /usr/src/target/release/text-generation-launcher \
    /usr/local/bin/

# AWS Sagemaker compatible image.
# Built on base-copy: the TGI server, router and launcher are installed there,
# whereas `base` only holds the ROCm/PyTorch toolchain and has none of them.
FROM base-copy AS sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image: the fully installed runtime, started via the TGI entrypoint
# wrapper with `--json-output` as the default argument.
FROM base-copy

COPY tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]