diff --git a/Dockerfile b/Dockerfile
index 25616f29..78870f49 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,3 +1,4 @@
+# Rust builder
 FROM lukemathwalker/cargo-chef:latest-rust-1.67 AS chef
 WORKDIR /usr/src
 
@@ -27,49 +28,117 @@ COPY router router
 COPY launcher launcher
 RUN cargo build --release
 
-FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as base
+# CUDA kernel builder
+# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
+FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as kernel-builder
 
+ARG PYTORCH_VERSION=2.0.0
+ARG PYTHON_VERSION=3.9
+ARG CUDA_VERSION=11.8
 ARG MAMBA_VERSION=23.1.0-1
+ARG CUDA_CHANNEL=nvidia
+ARG INSTALL_CHANNEL=pytorch
+# Automatically set by buildx
+ARG TARGETPLATFORM
 
-ENV LANG=C.UTF-8 \
-    LC_ALL=C.UTF-8 \
-    DEBIAN_FRONTEND=noninteractive \
-    HUGGINGFACE_HUB_CACHE=/data \
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+        build-essential \
+        ca-certificates \
+        ccache \
+        ninja-build \
+        cmake \
+        curl \
+        git && \
+        rm -rf /var/lib/apt/lists/*
+RUN /usr/sbin/update-ccache-symlinks && \
+    mkdir /opt/ccache && \
+    ccache --set-config=cache_dir=/opt/ccache
+ENV PATH /opt/conda/bin:$PATH
+
+# Install conda
+# translating Docker's TARGETPLATFORM into mamba arches
+RUN case ${TARGETPLATFORM} in \
+        "linux/arm64") MAMBA_ARCH=aarch64 ;; \
+        *) MAMBA_ARCH=x86_64 ;; \
+    esac && \
+    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
+RUN chmod +x ~/mambaforge.sh && \
+    bash ~/mambaforge.sh -b -p /opt/conda && \
+    rm ~/mambaforge.sh
+
+# Install pytorch
+# On arm64 we exit with an error code
+RUN case ${TARGETPLATFORM} in \
+        "linux/arm64") exit 1 ;; \
+        *) /opt/conda/bin/conda update -y conda && \
+           /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch==$PYTORCH_VERSION "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
+    esac && \
+    /opt/conda/bin/conda clean -ya
+
+
+# Build Flash Attention CUDA kernels
+FROM kernel-builder as flash-att-builder
+
+WORKDIR /usr/src
+
+COPY server/Makefile-flash-att Makefile
+
+# Build specific version of flash attention
+RUN make build-flash-attention
+
+# Build Transformers CUDA kernels
+FROM kernel-builder as transformers-builder
+
+WORKDIR /usr/src
+
+COPY server/Makefile-transformers Makefile
+
+# Build specific version of transformers
+RUN BUILD_EXTENSIONS="True" make build-transformers
+
+# Text Generation Inference base image
+FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as base
+
+# Conda env
+ENV PATH=/opt/conda/bin:$PATH
+
+# Text Generation Inference base env
+ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     MODEL_ID=bigscience/bloom-560m \
     QUANTIZE=false \
     NUM_SHARD=1 \
-    PORT=80 \
-    CUDA_HOME=/usr/local/cuda \
-    LD_LIBRARY_PATH="/opt/conda/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
-    PATH=$PATH:/opt/conda/bin:/usr/local/cuda/bin
+    PORT=80
 
-RUN apt-get update && apt-get install -y git curl libssl-dev ninja-build && rm -rf /var/lib/apt/lists/*
-
-RUN cd ~ && \
-    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh" \
-    chmod +x ~/mambaforge.sh && \
-    bash ~/mambaforge.sh -b -p /opt/conda && \
-    rm ~/mambaforge.sh
+LABEL com.nvidia.volumes.needed="nvidia_driver"
 
 WORKDIR /usr/src
 
-# Install torch
-RUN pip install torch==2.0.0 --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+        libssl-dev \
+        ca-certificates \
+        make \
+        && rm -rf /var/lib/apt/lists/*
 
-# Install specific version of flash attention
-COPY server/Makefile-flash-att server/Makefile
-RUN cd server && make install-flash-attention
+# Copy conda with PyTorch installed
+COPY --from=kernel-builder /opt/conda /opt/conda
 
-# Install specific version of transformers
-COPY server/Makefile-transformers server/Makefile
-RUN cd server && BUILD_EXTENSIONS="True" make install-transformers
+# Copy build artifacts from flash attention builder
+COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 
-COPY server/Makefile server/Makefile
+# Copy build artifacts from transformers builder
+COPY --from=transformers-builder /usr/src/transformers /usr/src/transformers
+COPY --from=transformers-builder /usr/src/transformers/build/lib.linux-x86_64-cpython-39/transformers /usr/src/transformers/src/transformers
+
+# Install transformers dependencies
+RUN cd /usr/src/transformers && pip install -e . --no-cache-dir && pip install einops --no-cache-dir
 
 # Install server
 COPY proto proto
 COPY server server
+COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install ".[bnb]" --no-cache-dir