diff --git a/Dockerfile_amd b/Dockerfile_amd
index 1bce3eea..c3228ace 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -47,27 +47,33 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     make \
     libssl-dev \
     g++ \
-    wget \
     # Needed to build VLLM & flash.
     rocthrust-dev \
     hipsparse-dev \
     hipblas-dev && \
     rm -rf /var/lib/apt/lists/*

-# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
-RUN wget \
-    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
-    && mkdir .conda \
-    && bash Miniconda3-latest-Linux-x86_64.sh -b \
-    && rm -f Miniconda3-latest-Linux-x86_64.sh
-
-ENV PATH="/root/miniconda3/bin:${PATH}"
-ARG PATH="/root/user/miniconda3/bin:${PATH}"
-RUN conda init bash
-
+# Keep in sync with `server/pyproject.toml
+ARG MAMBA_VERSION=23.1.0-1
 ARG PYTORCH_VERSION='2.2.0.dev0'
 ARG ROCM_VERSION='5.7'
-ARG PYTHON_VERSION='3.11.5'
+ARG PYTHON_VERSION='3.10.10'
+# Automatically set by buildx
+ARG TARGETPLATFORM
+ENV PATH /opt/conda/bin:$PATH
+
+# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
+# Install mamba
+# translating Docker's TARGETPLATFORM into mamba arches
+RUN case ${TARGETPLATFORM} in \
+        "linux/arm64") MAMBA_ARCH=aarch64 ;; \
+        *) MAMBA_ARCH=x86_64 ;; \
+    esac && \
+    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
+RUN chmod +x ~/mambaforge.sh && \
+    bash ~/mambaforge.sh -b -p /opt/conda && \
+    mamba init && \
+    rm ~/mambaforge.sh

 # Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
 RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
@@ -106,13 +112,13 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
     PORT=80

 # Copy builds artifacts from vllm builder
-COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
+COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

 # Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
+COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

 # Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index dce84daa..8e5c2210 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -10,7 +10,7 @@ build-vllm-rocm: build-vllm

 vllm:
 	# Clone vllm
-	pip install -U ninja --no-cache-dir
+	pip install -U ninja packaging --no-cache-dir
 	git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm

 build-vllm: vllm
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index b0f7394c..aca95e11 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -90,6 +90,9 @@ def attention(
             None,
         )
     elif HAS_FLASH_ATTN_V2_ROCM:
+        if window_size_left != -1:
+            raise ValueError(f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left}).")
+
         # RoCm flash API does not take the window_size_left and window_size_right arguments.
         return flash_attn_2_cuda.varlen_fwd(
             q,
diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py
index 72945b56..428c9f3e 100644
--- a/server/text_generation_server/utils/import_utils.py
+++ b/server/text_generation_server/utils/import_utils.py
@@ -1,16 +1,4 @@
-import subprocess
+import torch

-IS_CUDA_SYSTEM = False
-IS_ROCM_SYSTEM = False
-
-try:
-    subprocess.check_output("nvidia-smi")
-    IS_CUDA_SYSTEM = True
-except Exception:
-    IS_CUDA_SYSTEM = False
-
-try:
-    subprocess.check_output("rocm-smi")
-    IS_ROCM_SYSTEM = True
-except Exception:
-    IS_ROCM_SYSTEM = False
\ No newline at end of file
+IS_ROCM_SYSTEM = torch.version.hip is not None
+IS_CUDA_SYSTEM = torch.version.cuda is not None
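For readers of the `import_utils.py` change: ROCm builds of PyTorch report a HIP version string through `torch.version.hip` and leave `torch.version.cuda` as `None`, while CUDA builds do the reverse, which is why the patch can drop the `nvidia-smi`/`rocm-smi` subprocess probing. Below is a minimal sketch of how these flags can be consumed; the `select_attention_backend` helper is hypothetical and only illustrates the guard added in `flash_attn.py`, it is not part of this patch.

```python
import torch

# ROCm wheels expose a HIP version string (e.g. "5.7") via torch.version.hip;
# CUDA wheels expose torch.version.cuda instead. CPU-only builds have both set to None.
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None


def select_attention_backend(window_size_left: int = -1) -> str:
    # Hypothetical helper mirroring the new check in flash_attn.py: the ROCm
    # Flash Attention v2 kernel does not accept sliding-window arguments.
    if IS_ROCM_SYSTEM and window_size_left != -1:
        raise ValueError(
            "Sliding-window attention is not supported on the ROCm Flash Attention v2 path."
        )
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        return "flash-attention-v2"
    return "eager"
```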