update on review

Felix Marty 2023-11-08 11:01:10 +00:00
parent 6353a87f4c
commit 80ce8910f1
4 changed files with 29 additions and 32 deletions

@@ -47,27 +47,33 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     make \
     libssl-dev \
     g++ \
     wget \
     # Needed to build VLLM & flash.
     rocthrust-dev \
     hipsparse-dev \
     hipblas-dev && \
     rm -rf /var/lib/apt/lists/*
-
-# TGI seems to require libssl.so.1.1 instead of libssl.so.3, so we can't use Ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for Miniconda.
-RUN wget \
-    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
-    && mkdir .conda \
-    && bash Miniconda3-latest-Linux-x86_64.sh -b \
-    && rm -f Miniconda3-latest-Linux-x86_64.sh
-
-ENV PATH="/root/miniconda3/bin:${PATH}"
-ARG PATH="/root/user/miniconda3/bin:${PATH}"
-RUN conda init bash
-
+# Keep in sync with `server/pyproject.toml`
+ARG MAMBA_VERSION=23.1.0-1
 ARG PYTORCH_VERSION='2.2.0.dev0'
 ARG ROCM_VERSION='5.7'
-ARG PYTHON_VERSION='3.11.5'
+ARG PYTHON_VERSION='3.10.10'
+# Automatically set by buildx
+ARG TARGETPLATFORM
+ENV PATH /opt/conda/bin:$PATH
+
+# TGI seems to require libssl.so.1.1 instead of libssl.so.3, so we can't use Ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for Miniconda.
+# Install mamba
+# translating Docker's TARGETPLATFORM into mamba arches
+RUN case ${TARGETPLATFORM} in \
+        "linux/arm64") MAMBA_ARCH=aarch64 ;; \
+        *) MAMBA_ARCH=x86_64 ;; \
+    esac && \
+    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
+RUN chmod +x ~/mambaforge.sh && \
+    bash ~/mambaforge.sh -b -p /opt/conda && \
+    mamba init && \
+    rm ~/mambaforge.sh
+
 # Install PyTorch nightly (2.2.0.dev20231106) compiled against RoCm 5.7, as VLLM cannot be compiled with RoCm 5.6.
 RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
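
A quick sanity check (illustrative, not part of the commit) that the pin resolved to a RoCm build: PyTorch wheels carry the toolkit in the local version tag, e.g. "2.2.0.dev20231106+rocm5.7" (the exact tag format is an assumption about nightly wheel naming).

    # Illustrative check, run inside the image after the install step above.
    from importlib.metadata import version

    torch_version = version("torch")
    assert "rocm5.7" in torch_version, f"unexpected torch build: {torch_version}"
    print(torch_version)
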
@@ -106,13 +112,13 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
     PORT=80

 # Copy build artifacts from vllm builder
-COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
+COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

 # Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
+COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

 # Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
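
These COPY destinations hard-code the interpreter layout, which is why the Python bump from 3.11 to 3.10 has to land together with them: artifacts built for cpython-310 must end up on the import path of a Python 3.10 interpreter under /opt/conda. A minimal sketch (expected values taken from the diff) to verify the layout inside the final image:

    # Illustrative sanity check that the interpreter on PATH matches the
    # destinations used by the COPY lines above.
    import site
    import sys

    assert sys.version_info[:2] == (3, 10), sys.version_info
    assert sys.executable.startswith("/opt/conda/"), sys.executable
    print(site.getsitepackages())  # should include /opt/conda/lib/python3.10/site-packages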

@@ -10,7 +10,7 @@ build-vllm-rocm: build-vllm

 vllm:
 	# Clone vllm
-	pip install -U ninja --no-cache-dir
+	pip install -U ninja packaging --no-cache-dir
 	git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm

 build-vllm: vllm
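
packaging joins ninja here because vllm's build scripts import it at setup time, before pip has resolved vllm's own dependencies, so in a bare builder image the clone-and-build step would otherwise fail on the import. A hedged illustration of the kind of version gate a setup.py performs with it (vllm's actual checks may differ):

    # Illustrative only: version gating of the sort done with `packaging`,
    # consistent with the Dockerfile's note that vllm cannot be compiled
    # against RoCm 5.6.
    from packaging.version import Version, parse

    rocm_version = parse("5.7.0")  # hypothetical value discovered from the ROCm install
    if rocm_version < Version("5.7"):
        raise RuntimeError("vllm's RoCm backend requires RoCm >= 5.7")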

@@ -90,6 +90,9 @@ def attention(
             None,
         )
     elif HAS_FLASH_ATTN_V2_ROCM:
+        if window_size_left != -1:
+            raise ValueError(f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left}).")
+
         # RoCm flash API does not take the window_size_left and window_size_right arguments.
         return flash_attn_2_cuda.varlen_fwd(
             q,
@@ -1,16 +1,4 @@
-import subprocess
 import torch

-IS_CUDA_SYSTEM = False
-IS_ROCM_SYSTEM = False
-
-try:
-    subprocess.check_output("nvidia-smi")
-    IS_CUDA_SYSTEM = True
-except Exception:
-    IS_CUDA_SYSTEM = False
-try:
-    subprocess.check_output("rocm-smi")
-    IS_ROCM_SYSTEM = True
-except Exception:
-    IS_ROCM_SYSTEM = False
+IS_ROCM_SYSTEM = torch.version.hip is not None
+IS_CUDA_SYSTEM = torch.version.cuda is not None
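
The rewrite is more than cleanup: shelling out to nvidia-smi / rocm-smi tests what the host happens to have on PATH, while torch.version.cuda / torch.version.hip report which toolkit the installed torch wheel was built against, which is what kernel dispatch actually needs; it also spawns no processes and cannot misfire in containers where the SMI tools are absent. A short usage sketch of the new flags:

    # The new detection reads attributes of the torch build itself.
    import torch

    IS_ROCM_SYSTEM = torch.version.hip is not None
    IS_CUDA_SYSTEM = torch.version.cuda is not None

    if IS_ROCM_SYSTEM:
        print(f"RoCm build: {torch.version.hip}")
    elif IS_CUDA_SYSTEM:
        print(f"CUDA build: {torch.version.cuda}")
    else:
        print("CPU-only torch build")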