Mirror of https://github.com/huggingface/text-generation-inference.git

commit 80ce8910f1 ("update on review")
parent 6353a87f4c
@@ -47,27 +47,33 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     make \
     libssl-dev \
     g++ \
-    wget \
     # Needed to build VLLM & flash.
     rocthrust-dev \
     hipsparse-dev \
     hipblas-dev && \
     rm -rf /var/lib/apt/lists/*
 
-# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
-RUN wget \
-    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
-    && mkdir .conda \
-    && bash Miniconda3-latest-Linux-x86_64.sh -b \
-    && rm -f Miniconda3-latest-Linux-x86_64.sh
-
-ENV PATH="/root/miniconda3/bin:${PATH}"
-ARG PATH="/root/user/miniconda3/bin:${PATH}"
-RUN conda init bash
-
+# Keep in sync with `server/pyproject.toml
+ARG MAMBA_VERSION=23.1.0-1
 ARG PYTORCH_VERSION='2.2.0.dev0'
 ARG ROCM_VERSION='5.7'
-ARG PYTHON_VERSION='3.11.5'
+ARG PYTHON_VERSION='3.10.10'
+# Automatically set by buildx
+ARG TARGETPLATFORM
+ENV PATH /opt/conda/bin:$PATH
+
+# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
+# Install mamba
+# translating Docker's TARGETPLATFORM into mamba arches
+RUN case ${TARGETPLATFORM} in \
+         "linux/arm64")  MAMBA_ARCH=aarch64  ;; \
+         *)              MAMBA_ARCH=x86_64   ;; \
+    esac && \
+    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
+RUN chmod +x ~/mambaforge.sh && \
+    bash ~/mambaforge.sh -b -p /opt/conda && \
+    mamba init && \
+    rm ~/mambaforge.sh
 
 # Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
 RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
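
This hunk replaces the hand-rolled Miniconda install under /root/miniconda3 with a Mambaforge install in /opt/conda, pins Python to 3.10.10, and keeps the PyTorch nightly pinned to a ROCm 5.7 wheel. An illustrative check, not part of the commit, to confirm the built image actually picked up a ROCm build of torch:

# Illustrative check (not part of the commit): run inside the built image to confirm
# the pinned nightly wheel is a ROCm build rather than a CPU or CUDA build.
import torch

print(torch.__version__)   # expected to start with "2.2.0.dev20231106"
print(torch.version.hip)   # a HIP/ROCm version string on ROCm builds, None otherwise
assert torch.version.hip is not None, "torch was not built against ROCm"
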
@@ -106,13 +112,13 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
     PORT=80
 
 # Copy builds artifacts from vllm builder
-COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
+COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
 # Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
+COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
 # Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
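
The COPY destinations change because the compiled extensions are now built against CPython 3.10 and must land in the site-packages of the interpreter that will import them: the mamba-provided Python 3.10 under /opt/conda rather than the old /root/miniconda3 Python 3.11. An illustrative way, not part of the commit, to confirm the target directory inside the image:

# Illustrative only: print the directory where compiled extension packages are
# expected, which should match the COPY destinations above.
import sysconfig

print(sysconfig.get_paths()["platlib"])   # e.g. /opt/conda/lib/python3.10/site-packages
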
@@ -10,7 +10,7 @@ build-vllm-rocm: build-vllm
 
 vllm:
 	# Clone vllm
-	pip install -U ninja --no-cache-dir
+	pip install -U ninja packaging --no-cache-dir
 	git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
 
 build-vllm: vllm
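
The only change in this target adds packaging to the build-time pip install, presumably because the vllm build scripts import it. A hypothetical preflight check, not part of the repository, that the build-time Python dependencies are importable before kicking off the compile:

# Hypothetical preflight check: verify the build-time dependencies installed by
# this Makefile target can be imported.
import importlib

for module in ("ninja", "packaging"):
    importlib.import_module(module)
print("vllm build-time deps importable")
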
@@ -90,6 +90,9 @@ def attention(
             None,
         )
     elif HAS_FLASH_ATTN_V2_ROCM:
+        if window_size_left != -1:
+            raise ValueError(f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left}).")
+
         # RoCm flash API does not take the window_size_left and window_size_right arguments.
         return flash_attn_2_cuda.varlen_fwd(
             q,
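
The guard added here rejects sliding-window requests on ROCm up front, since the ROCm varlen_fwd entry point takes no window_size_left/window_size_right arguments and would otherwise silently run full attention. A minimal, self-contained sketch of the same check (the helper name is illustrative, not part of the commit):

# Minimal sketch of the guard added in this hunk; `check_window_support` is a
# hypothetical helper used only for illustration.
def check_window_support(window_size_left: int, is_rocm: bool) -> None:
    if is_rocm and window_size_left != -1:
        raise ValueError(
            "RoCm version of Flash Attention v2 does not support window attention "
            f"(window_size_left != -1, got window_size_left={window_size_left})."
        )

check_window_support(window_size_left=-1, is_rocm=True)      # full attention: fine
# check_window_support(window_size_left=4096, is_rocm=True)  # would raise ValueError
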
@@ -1,16 +1,4 @@
-import subprocess
+import torch
 
-IS_CUDA_SYSTEM = False
-IS_ROCM_SYSTEM = False
-
-try:
-    subprocess.check_output("nvidia-smi")
-    IS_CUDA_SYSTEM = True
-except Exception:
-    IS_CUDA_SYSTEM = False
-
-try:
-    subprocess.check_output("rocm-smi")
-    IS_ROCM_SYSTEM = True
-except Exception:
-    IS_ROCM_SYSTEM = False
+IS_ROCM_SYSTEM = torch.version.hip is not None
+IS_CUDA_SYSTEM = torch.version.cuda is not None
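
The rewritten detection keys off the installed PyTorch build (torch.version.hip / torch.version.cuda) instead of shelling out to nvidia-smi or rocm-smi, so it gives a stable answer even where those tools are not on PATH, for example during an image build. A minimal usage sketch, assuming some PyTorch build is installed:

# Minimal usage sketch: torch.version.cuda is a version string on CUDA builds and
# None otherwise; torch.version.hip behaves the same way for ROCm builds.
import torch

IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
print(f"ROCm build: {IS_ROCM_SYSTEM}, CUDA build: {IS_CUDA_SYSTEM}")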