This commit is contained in:
OlivierDehaene 2023-11-20 19:03:32 +01:00
parent 2e75027187
commit abb498a907
5 changed files with 843 additions and 631 deletions

View File

@ -37,10 +37,10 @@ RUN cargo build --release
# Python builder # Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM debian:bullseye-slim as pytorch-install FROM debian:bullseye-slim as conda-install
ARG PYTORCH_VERSION=2.1.0 ARG PYTORCH_VERSION=2.1.0
ARG PYTHON_VERSION=3.9 ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml # Keep in sync with `server/pyproject.toml
ARG CUDA_VERSION=12.1 ARG CUDA_VERSION=12.1
ARG MAMBA_VERSION=23.1.0-1 ARG MAMBA_VERSION=23.1.0-1
@ -70,23 +70,35 @@ RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \ bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh rm ~/mambaforge.sh
FROM conda-install as pytorch-install
# Install pytorch # Install pytorch
# On arm64 we exit with an error code # On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \ RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \ "linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \ *) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch==$PYTORCH_VERSION "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \ /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch==$PYTORCH_VERSION "pytorch-cuda=${CUDA_VERSION}" ;; \
esac && \ esac && \
/opt/conda/bin/conda clean -ya /opt/conda/bin/conda clean -ya
# CUDA kernels builder image # CUDA kernels builder image
FROM pytorch-install as kernel-builder FROM conda-install as kernel-builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ninja-build \ ninja-build \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \ # FIXME: for some reason if we install cuda after, some libs are not properly linked...
RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-12.1.0" cuda && \
/opt/conda/bin/conda clean -ya
# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch==$PYTORCH_VERSION "pytorch-cuda=${CUDA_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya /opt/conda/bin/conda clean -ya
# Build Flash Attention CUDA kernels # Build Flash Attention CUDA kernels
@ -172,24 +184,24 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY --from=pytorch-install /opt/conda /opt/conda COPY --from=pytorch-install /opt/conda /opt/conda
# Copy build artifacts from flash attention builder # Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from flash attention v2 builder # Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from custom kernels builder # Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from exllama kernels builder # Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from awq kernels builder # Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from eetq kernels builder # Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.9/site-packages
# Copy builds artifacts from vllm builder # Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.9/site-packages
# Install flash-attention dependencies # Install flash-attention dependencies
RUN pip install einops --no-cache-dir RUN pip install einops --no-cache-dir

View File

@ -18,7 +18,7 @@ gen-server:
install-torch: install-torch:
# Install specific version of torch # Install specific version of torch
pip install torch==2.1.0 --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir pip install torch==2.1.0 --extra-index-url https://download.pytorch.org/whl/cu121 --no-cache-dir
install: gen-server install-torch install: gen-server install-torch
pip install pip --upgrade pip install pip --upgrade

1370
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -47,7 +47,7 @@ pytest = "^7.3.0"
[[tool.poetry.source]] [[tool.poetry.source]]
name = "pytorch-gpu-src" name = "pytorch-gpu-src"
url = "https://download.pytorch.org/whl/cu118" url = "https://download.pytorch.org/whl/cu121"
priority = "explicit" priority = "explicit"
[tool.pytest.ini_options] [tool.pytest.ini_options]

View File

@ -1,28 +1,28 @@
accelerate==0.20.3 ; python_version >= "3.9" and python_version < "3.13" accelerate==0.20.3 ; python_version >= "3.9" and python_version < "3.13"
aiohttp==3.8.6 ; python_version >= "3.9" and python_version < "3.13" aiohttp==3.9.0 ; python_version >= "3.9" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13" aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13" async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13" attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13" bitsandbytes==0.41.2.post2 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13" certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.0 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
datasets==2.14.4 ; python_version >= "3.9" and python_version < "3.13" datasets==2.14.7 ; python_version >= "3.9" and python_version < "3.13"
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
dill==0.3.7 ; python_version >= "3.9" and python_version < "3.13" dill==0.3.7 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.12.4 ; python_version >= "3.9" and python_version < "3.13" filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "3.13" frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.9.2 ; python_version >= "3.9" and python_version < "3.13" fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec[http]==2023.9.2 ; python_version >= "3.9" and python_version < "3.13" fsspec[http]==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.59.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.59.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.59.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13" idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13" jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
@ -31,8 +31,20 @@ markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13" multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13"
multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "3.13" multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2 ; python_version >= "3.9" and python_version < "3.13" networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.1 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cudnn-cu12==8.9.2.26 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-nccl-cu12==2.18.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-nvjitlink-cu12==12.3.101 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -43,33 +55,35 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13" packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pandas==2.1.1 ; python_version >= "3.9" and python_version < "3.13" pandas==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
peft==0.4.0 ; python_version >= "3.9" and python_version < "3.13" peft==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13" pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.24.4 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
psutil==5.9.6 ; python_version >= "3.9" and python_version < "3.13" psutil==5.9.6 ; python_version >= "3.9" and python_version < "3.13"
pyarrow==13.0.0 ; python_version >= "3.9" and python_version < "3.13" pyarrow-hotfix==0.5 ; python_version >= "3.9" and python_version < "3.13"
pyarrow==14.0.1 ; python_version >= "3.9" and python_version < "3.13"
python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "3.13" python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "3.13"
pytz==2023.3.post1 ; python_version >= "3.9" and python_version < "3.13" pytz==2023.3.post1 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13" regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.2 ; python_version >= "3.9" and python_version < "3.13" setuptools==68.2.2 ; python_version >= "3.9" and python_version < "3.13"
six==1.16.0 ; python_version >= "3.9" and python_version < "3.13" six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
texttable==1.7.0 ; python_version >= "3.9" and python_version < "3.13" texttable==1.7.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
torch==2.1.0 ; python_version >= "3.9" and python_version < "3.13" torch==2.1.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13" transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
triton==2.1.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
tzdata==2023.3 ; python_version >= "3.9" and python_version < "3.13" tzdata==2023.3 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.7 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.4.1 ; python_version >= "3.9" and python_version < "3.13" xxhash==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13" yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"