mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
update on review
This commit is contained in:
parent
6353a87f4c
commit
80ce8910f1
@ -47,27 +47,33 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
||||
make \
|
||||
libssl-dev \
|
||||
g++ \
|
||||
wget \
|
||||
# Needed to build VLLM & flash.
|
||||
rocthrust-dev \
|
||||
hipsparse-dev \
|
||||
hipblas-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
|
||||
RUN wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir .conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ARG PATH="/root/user/miniconda3/bin:${PATH}"
|
||||
RUN conda init bash
|
||||
|
||||
# Keep in sync with `server/pyproject.toml
|
||||
ARG MAMBA_VERSION=23.1.0-1
|
||||
ARG PYTORCH_VERSION='2.2.0.dev0'
|
||||
ARG ROCM_VERSION='5.7'
|
||||
ARG PYTHON_VERSION='3.11.5'
|
||||
ARG PYTHON_VERSION='3.10.10'
|
||||
# Automatically set by buildx
|
||||
ARG TARGETPLATFORM
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
|
||||
# Install mamba
|
||||
# translating Docker's TARGETPLATFORM into mamba arches
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
|
||||
*) MAMBA_ARCH=x86_64 ;; \
|
||||
esac && \
|
||||
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
|
||||
RUN chmod +x ~/mambaforge.sh && \
|
||||
bash ~/mambaforge.sh -b -p /opt/conda && \
|
||||
mamba init && \
|
||||
rm ~/mambaforge.sh
|
||||
|
||||
# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
|
||||
RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
|
||||
@ -106,13 +112,13 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
PORT=80
|
||||
|
||||
# Copy builds artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Copy build artifacts from flash attention v2 builder
|
||||
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
|
||||
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Copy build artifacts from custom kernels builder
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Install flash-attention dependencies
|
||||
RUN pip install einops --no-cache-dir
|
||||
|
@ -10,7 +10,7 @@ build-vllm-rocm: build-vllm
|
||||
|
||||
vllm:
|
||||
# Clone vllm
|
||||
pip install -U ninja --no-cache-dir
|
||||
pip install -U ninja packaging --no-cache-dir
|
||||
git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
|
||||
|
||||
build-vllm: vllm
|
||||
|
@ -90,6 +90,9 @@ def attention(
|
||||
None,
|
||||
)
|
||||
elif HAS_FLASH_ATTN_V2_ROCM:
|
||||
if window_size_left != -1:
|
||||
raise ValueError(f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left}).")
|
||||
|
||||
# RoCm flash API does not take the window_size_left and window_size_right arguments.
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
|
@ -1,16 +1,4 @@
|
||||
import subprocess
|
||||
import torch
|
||||
|
||||
IS_CUDA_SYSTEM = False
|
||||
IS_ROCM_SYSTEM = False
|
||||
|
||||
try:
|
||||
subprocess.check_output("nvidia-smi")
|
||||
IS_CUDA_SYSTEM = True
|
||||
except Exception:
|
||||
IS_CUDA_SYSTEM = False
|
||||
|
||||
try:
|
||||
subprocess.check_output("rocm-smi")
|
||||
IS_ROCM_SYSTEM = True
|
||||
except Exception:
|
||||
IS_ROCM_SYSTEM = False
|
||||
IS_ROCM_SYSTEM = torch.version.hip is not None
|
||||
IS_CUDA_SYSTEM = torch.version.cuda is not None
|
||||
|
Loading…
Reference in New Issue
Block a user