use mamba

This commit is contained in:
OlivierDehaene 2023-04-13 16:48:21 +02:00
parent 4cfef0441f
commit f1ddbf5c72
2 changed files with 12 additions and 23 deletions

View File

@ -91,9 +91,6 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
build-args: |
KERNEL_BUILDER_IMAGE=registry.internal.huggingface.tech/pytorch-base-images/kernel-builder:2.0.0-cuda11.8
PYTORCH_IMAGE=registry.internal.huggingface.tech/pytorch-base-images/torch:2.0.0-cuda11.8
file: Dockerfile
push: ${{ github.event_name != 'pull_request' }}
platforms: 'linux/amd64'

View File

@ -1,8 +1,3 @@
# allow using other images to build kernels
ARG KERNEL_BUILDER_IMAGE=kernel-builder
# Allow using other images as pytorch base image
ARG PYTORCH_IMAGE=pytorch-install
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.67 AS chef
WORKDIR /usr/src
@ -40,6 +35,7 @@ FROM ubuntu:22.04 as pytorch-install
ARG PYTORCH_VERSION=2.0.0
ARG PYTHON_VERSION=3.9
ARG CUDA_VERSION=11.8
ARG MAMBA_VERSION=23.1.0-1
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx
@ -59,16 +55,15 @@ RUN /usr/sbin/update-ccache-symlinks && \
ENV PATH /opt/conda/bin:$PATH
# Install conda
# translating Docker's TARGETPLATFORM into miniconda arches
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MINICONDA_ARCH=aarch64 ;; \
*) MINICONDA_ARCH=x86_64 ;; \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/miniconda.sh -O "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh"
# Manually invoke bash on miniconda script per https://github.com/conda/conda/issues/10431
RUN chmod +x ~/miniconda.sh && \
bash ~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh
# Install pytorch
# On arm64 we exit with an error code
@ -80,7 +75,7 @@ RUN case ${TARGETPLATFORM} in \
/opt/conda/bin/conda clean -ya
# CUDA kernels builder image
FROM $PYTORCH_IMAGE as kernel-builder
FROM pytorch-install as kernel-builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ninja-build \
@ -91,7 +86,7 @@ RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
# Build Flash Attention CUDA kernels
FROM $KERNEL_BUILDER_IMAGE as flash-att-builder
FROM kernel-builder as flash-att-builder
WORKDIR /usr/src
@ -101,7 +96,7 @@ COPY server/Makefile-flash-att Makefile
RUN make build-flash-attention
# Build Transformers CUDA kernels
FROM $KERNEL_BUILDER_IMAGE as transformers-builder
FROM kernel-builder as transformers-builder
WORKDIR /usr/src
@ -110,9 +105,6 @@ COPY server/Makefile-transformers Makefile
# Build specific version of transformers
RUN BUILD_EXTENSIONS="True" make build-transformers
# re-export because `COPY --from` does not support ARG vars directly
FROM $PYTORCH_IMAGE as pytorch
# Text Generation Inference base image
FROM ubuntu:22.04 as base
@ -144,7 +136,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
&& rm -rf /var/lib/apt/lists/*
# Copy conda with PyTorch installed
COPY --from=pytorch /opt/conda /opt/conda
COPY --from=pytorch-install /opt/conda /opt/conda
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages