clean rocm support

This commit is contained in:
Felix Marty 2023-11-07 14:56:11 +00:00
parent 52bdcf797d
commit ea8438a5a0
20 changed files with 378 additions and 263 deletions

View File

@ -107,7 +107,7 @@ WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2 # Build specific version of flash attention v2
RUN make build-flash-attention-v2 RUN make build-flash-attention-v2-cuda
# Build Transformers exllama kernels # Build Transformers exllama kernels
FROM kernel-builder as exllama-kernels-builder FROM kernel-builder as exllama-kernels-builder
@ -145,7 +145,7 @@ WORKDIR /usr/src
COPY server/Makefile-vllm Makefile COPY server/Makefile-vllm Makefile
# Build specific version of vllm # Build specific version of vllm
RUN make build-vllm RUN make build-vllm-cuda
# Text Generation Inference base image # Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
@ -200,7 +200,8 @@ COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements.txt && \ pip install -r requirements_common.txt && \
pip install -r requirements_cuda.txt && \
pip install ".[bnb, accelerate, quantize]" --no-cache-dir pip install ".[bnb, accelerate, quantize]" --no-cache-dir
# Install benchmarker # Install benchmarker
@ -215,7 +216,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
g++ \ g++ \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# AWS Sagemaker compatbile image # AWS Sagemaker compatible image
FROM base as sagemaker FROM base as sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh COPY sagemaker-entrypoint.sh entrypoint.sh

View File

@ -35,7 +35,7 @@ COPY router router
COPY launcher launcher COPY launcher launcher
RUN cargo build --release RUN cargo build --release
# Text Generation Inference base image # Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-20.04:5.7 as base FROM rocm/dev-ubuntu-20.04:5.7 as base
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
@ -48,14 +48,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
libssl-dev \ libssl-dev \
g++ \ g++ \
wget \ wget \
# Needed to build VLLM. # Needed to build VLLM & flash.
rocthrust-dev \ rocthrust-dev \
hipsparse-dev \ hipsparse-dev \
hipblas-dev && \ hipblas-dev && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# TGI seems to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. # TGI seems to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
RUN wget \ RUN wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir .conda \ && mkdir .conda \
@ -70,66 +69,80 @@ ARG PYTORCH_VERSION='2.2.0.dev0'
ARG ROCM_VERSION='5.7' ARG ROCM_VERSION='5.7'
ARG PYTHON_VERSION='3.11.5' ARG PYTHON_VERSION='3.11.5'
RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
RUN pip install -U ninja RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
FROM base AS kernel-builder
# Build vllm kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src WORKDIR /usr/src
# Install VLLM. COPY server/Makefile-vllm Makefile
RUN git clone https://github.com/fxmarty/vllm-public.git && cd vllm-public && git checkout --track origin/port-to-rocm
WORKDIR /usr/src/vllm-public
RUN pip install -r requirements.txt
RUN python setup.py install
# Install Flash Attention v1. # Build specific version of vllm
RUN make build-vllm-rocm
# Build Flash Attention v2 kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src WORKDIR /usr/src
RUN git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git && cd flash-attention && git submodule init && git submodule update && python setup.py install
# Not working for RoCm COPY server/Makefile-flash-att-v2 Makefile
# RUN cd flash-attention/csrc/rotary && python setup.py build && cd flash-attention/csrc/layer_norm && python setup.py build
# COPY server/Makefile-flash-att Makefile # Build specific version of flash attention v2
RUN make build-flash-attention-v2-rocm
# Build specific version of flash attention # Build Transformers CUDA kernels (gpt-neox and bloom)
# RUN make build-flash-attention FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
# Build Transformers CUDA kernels FROM base as base-copy
# NOTE: gpt-neox and bloom fused kernels
# FROM kernel-builder as custom-kernels-builder
# WORKDIR /usr/src
# COPY server/custom_kernels/ .
# Build specific version of transformers
# RUN python setup.py build
# Text Generation Inference base env # Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \ ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \ HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80 PORT=80
# Copy build artifacts from flash attention builder # Copy builds artifacts from vllm builder
# COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages # Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder # Copy build artifacts from custom kernels builder
# COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir
# Install server # Install server
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
RUN cd server && pip3 install -r requirements.txt RUN cd server && pip install -r requirements_common.txt && \
pip install -r requirements_rocm.txt
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && pip install ".[accelerate]" --no-cache-dir
pip3 install ".[accelerate]" --no-cache-dir
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router # Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcherg # Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
# ENTRYPOINT ["text-generation-launcher"] # AWS Sagemaker compatible image
# CMD ["--json-output"] FROM base-copy as sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh
ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]

View File

@ -72,7 +72,9 @@ curl 127.0.0.1:8080/generate \
-H 'Content-Type: application/json' -H 'Content-Type: application/json'
``` ```
**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 [to some extent](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0+rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
``` ```
@ -183,7 +185,7 @@ sudo apt-get install libssl-dev gcc -y
### CUDA Kernels ### CUDA Kernels
The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove The custom CUDA kernels are only tested on NVIDIA A100, AMD MI210 and AMD MI250. If you have any installation or runtime issues, you can remove
the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable. the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.
Be aware that the official Docker image has them enabled by default. Be aware that the official Docker image has them enabled by default.

View File

@ -254,7 +254,7 @@ Options:
## DISABLE_CUSTOM_KERNELS ## DISABLE_CUSTOM_KERNELS
```shell ```shell
--disable-custom-kernels --disable-custom-kernels
For some models (like bloom), text-generation-inference implemented custom cuda kernels to speed up inference. Those kernels were only tested on A100. Use this flag to disable them if you're running on different hardware and encounter issues For some models (like bloom), text-generation-inference implemented custom cuda kernels to speed up inference. Those kernels were only tested on Nvidia A100, AMD MI210 and AMD MI250. Use this flag to disable them if you're running on different hardware and encounter issues
[env: DISABLE_CUSTOM_KERNELS=] [env: DISABLE_CUSTOM_KERNELS=]

View File

@ -15,6 +15,8 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) . We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) . We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
To use TGI on ROCm-enabled AMD GPUs (only MI210 and MI250 are tested), please use the image `ghcr.io/huggingface/text-generation-inference:1.1.1+rocm` instead. For details about the usage on ROCm, please refer to the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html).
</Tip> </Tip>
Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint. Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.

View File

@ -39,9 +39,9 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
## Supported Hardware ## Supported Hardware
TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other hardware, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed. TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
TGI also has experimental support for ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are missing from the ROCm version of TGI: quantization, flash [rotary embedding kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary), flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm).
TGI is also supported on the following AI hardware accelerators: TGI is also supported on the following AI hardware accelerators:
- *Habana first-gen Gaudi and Gaudi2:* check out this [example](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index) - *Habana first-gen Gaudi and Gaudi2:* check out this [example](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)

View File

@ -2,7 +2,7 @@ flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
flash-attention: flash-attention:
# Clone flash attention # Clone flash attention
pip install packaging pip install -U packaging ninja --no-cache-dir
git clone https://github.com/HazyResearch/flash-attention.git git clone https://github.com/HazyResearch/flash-attention.git
build-flash-attention: flash-attention build-flash-attention: flash-attention

View File

@ -1,13 +1,26 @@
flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3 flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
build-flash-attention-v2-cuda: FLASH_ATTN_V2_COMMIT=02ac572f3ffc4f402e4183aaa6824b45859d3ed3
build-flash-attention-v2-cuda: FLASH_REPOSITORY=https://github.com/HazyResearch/flash-attention.git
build-flash-attention-v2-cuda: BRANCH=main
build-flash-attention-v2-cuda: PYTORCH_ROCM_ARCH=""
build-flash-attention-v2-cuda: build-flash-attention-v2
build-flash-attention-v2-rocm: FLASH_ATTN_V2_COMMIT=8736558c287ff2ef28b24878e42828c595ac3e69
build-flash-attention-v2-rocm: FLASH_REPOSITORY=https://github.com/fxmarty/flash-attention-rocm
build-flash-attention-v2-rocm: BRANCH=remove-offload-arch-native
build-flash-attention-v2-rocm: PYTORCH_ROCM_ARCH=gfx90a
build-flash-attention-v2-rocm: build-flash-attention-v2
flash-attention-v2: flash-attention-v2:
# Clone flash attention # Clone flash attention
pip install packaging pip install -U packaging ninja --no-cache-dir
git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 git clone --single-branch --branch $(BRANCH) $(FLASH_REPOSITORY) flash-attention-v2
build-flash-attention-v2: flash-attention-v2 build-flash-attention-v2: flash-attention-v2
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit) cd flash-attention-v2 && git fetch && git checkout $(FLASH_ATTN_V2_COMMIT)
cd flash-attention-v2 && python setup.py build cd flash-attention-v2 && git submodule update --init --recursive
cd flash-attention-v2 && PYTORCH_ROCM_ARCH=$(PYTORCH_ROCM_ARCH) python setup.py build
install-flash-attention-v2: build-flash-attention-v2 install-flash-attention-v2: build-flash-attention-v2
cd flash-attention-v2 && python setup.py install cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install

View File

@ -1,11 +1,20 @@
vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
build-vllm-cuda: BRANCH=main
build-vllm-cuda: build-vllm
build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
build-vllm-rocm: VLLM_COMMIT=65f4a79621b4d992cf97f6b84598804eb4ca87b6
build-vllm-rocm: BRANCH=port-to-rocm
build-vllm-rocm: build-vllm
vllm: vllm:
# Clone vllm # Clone vllm
git clone https://github.com/vllm-project/vllm.git pip install -U ninja --no-cache-dir
git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
build-vllm: vllm build-vllm: vllm
cd vllm && git fetch && git checkout $(vllm_commit) cd vllm && git fetch && git checkout $(VLLM_COMMIT)
cd vllm && python setup.py build cd vllm && python setup.py build
install-vllm: build-vllm install-vllm: build-vllm

View File

@ -1,5 +1,10 @@
from setuptools import setup from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
extra_compile_args = ["-std=c++17"]
if not torch.version.hip:
extra_compile_args.append("-arch=compute_80")
setup( setup(
name="custom_kernels", name="custom_kernels",
@ -7,12 +12,12 @@ setup(
CUDAExtension( CUDAExtension(
name="custom_kernels.fused_bloom_attention_cuda", name="custom_kernels.fused_bloom_attention_cuda",
sources=["custom_kernels/fused_bloom_attention_cuda.cu"], sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
extra_compile_args=["-arch=compute_80", "-std=c++17"], extra_compile_args=extra_compile_args,
), ),
CUDAExtension( CUDAExtension(
name="custom_kernels.fused_attention_cuda", name="custom_kernels.fused_attention_cuda",
sources=["custom_kernels/fused_attention_cuda.cu"], sources=["custom_kernels/fused_attention_cuda.cu"],
extra_compile_args=["-arch=compute_80", "-std=c++17"], extra_compile_args=extra_compile_args,
), ),
], ],
cmdclass={"build_ext": BuildExtension}, cmdclass={"build_ext": BuildExtension},

View File

@ -4,8 +4,6 @@ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13" async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13" attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
# bitsandbytes is broken on RoCm systems
# bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13" certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
@ -63,8 +61,6 @@ six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
texttable==1.6.7 ; python_version >= "3.9" and python_version < "3.13" texttable==1.6.7 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
# We use nightly
torch>2.1.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.33.2 ; python_version >= "3.9" and python_version < "3.13" transformers==4.33.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -0,0 +1,3 @@
torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
# bitsandbytes can not compile on RoCm systems, hence only installed for Nvidia GPUs
bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -0,0 +1,2 @@
# We use nightly
torch>2.1.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -28,9 +28,6 @@ from typing import Optional, List, Tuple
from loguru import logger from loguru import logger
# Flash attention imports
# import dropout_layer_norm
from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -40,8 +37,12 @@ from text_generation_server.utils.layers import (
TensorParallelHead, TensorParallelHead,
get_linear, get_linear,
) )
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from vllm import layernorm_ops if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
torch.set_printoptions(threshold=10000000, sci_mode=True) torch.set_printoptions(threshold=10000000, sci_mode=True)
@ -125,8 +126,31 @@ class LlamaRMSNorm(nn.Module):
hidden_states = hidden_states.to(self.weight.dtype) hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual return self.weight * hidden_states, residual
else: elif IS_CUDA_SYSTEM:
# We use VLLM kernels that are compiled for RoCm instead of Flash Attention ones that can't be used. # faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None: if residual is not None:
hidden_states += residual hidden_states += residual
residual = hidden_states residual = hidden_states
@ -140,30 +164,6 @@ class LlamaRMSNorm(nn.Module):
) )
return out, residual return out, residual
# else:
# # faster post attention rms norm
# normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
# hidden_states,
# residual,
# self.weight,
# None,
# None,
# None,
# None,
# None,
# 0.0,
# self.variance_epsilon,
# 1.0,
# 0,
# None,
# False,
# True, # Activate RMSNorm
# )
# if res is None:
# res = hidden_states
# return normed_hidden_states, res
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
@ -282,11 +282,6 @@ class FlashLlamaAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
# logger.info(f"query before rotary {query[:10, ..., :8]}")
# logger.info(f"cos before rotary {cos[:10]}")
# logger.info(f"sin before rotary {sin[:10]}")
# TODO: maybe we can use VLLM rotary here, which would require position_ids? Probably too big of a change...
# Flash Attention kernel may be usable since it is Triton-based
self.rotary_emb(query, cos, sin) self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
@ -297,9 +292,6 @@ class FlashLlamaAttention(torch.nn.Module):
# output tensor # output tensor
attn_output = torch.empty_like(query) attn_output = torch.empty_like(query)
# logger.info(f"query {query.shape}")
# logger.info(f"query piece {query[:10, ..., :8]}")
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
@ -326,9 +318,6 @@ class FlashLlamaAttention(torch.nn.Module):
max_s, max_s,
) )
# logger.info(f"attn_output {attn_output.shape}")
# logger.info(f"attn_output piece {attn_output[:10, ..., :8]}")
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -26,11 +26,8 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
# Flash attention imports
# import dropout_layer_norm
from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2 from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_ROCM
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -39,8 +36,14 @@ from text_generation_server.utils.layers import (
TensorParallelHead, TensorParallelHead,
get_linear, get_linear,
) )
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if not HAS_FLASH_ATTN_V2: if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
if not HAS_FLASH_ATTN_V2_ROCM and not HAS_FLASH_ATTN_V2_ROCM:
raise ImportError("Mistral model requires flash attn v2") raise ImportError("Mistral model requires flash attn v2")
@ -110,45 +113,59 @@ class MistralRMSNorm(nn.Module):
self.variance_epsilon = eps self.variance_epsilon = eps
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
# if hidden_states.shape[-1] > 8192: if hidden_states.shape[-1] > 8192:
if residual is not None: if residual is not None:
hidden_states += residual hidden_states += residual
residual = hidden_states residual = hidden_states
hidden_states = hidden_states.to(torch.float32) hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True) variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt( hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon variance + self.variance_epsilon
) )
# convert into half-precision if necessary # convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]: if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype) hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual return self.weight * hidden_states, residual
# else: elif IS_CUDA_SYSTEM:
# # faster post attention rms norm # faster post attention rms norm
# normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
# hidden_states, hidden_states,
# residual, residual,
# self.weight, self.weight,
# None, None,
# None, None,
# None, None,
# None, None,
# None, None,
# 0.0, 0.0,
# self.variance_epsilon, self.variance_epsilon,
# 1.0, 1.0,
# 0, 0,
# None, None,
# False, False,
# True, # Activate RMSNorm True, # Activate RMSNorm
# ) )
# if res is None: if res is None:
# res = hidden_states res = hidden_states
# return normed_hidden_states, res return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights):

View File

@ -55,8 +55,12 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding, PositionRotaryEmbedding,
FastLinear, FastLinear,
) )
# import dropout_layer_norm from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
@dataclass @dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast): class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
@ -354,54 +358,80 @@ class IdeficsRMSNorm(nn.Module):
self.variance_epsilon = eps self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
    """Apply RMS normalization over the last dimension of ``hidden_states``.

    Three implementations are selected at call time:

    * last dimension larger than 8192: plain PyTorch ops, computed in float32;
    * CUDA system: ``dropout_layer_norm.dropout_add_ln_fwd`` with its RMSNorm
      flag enabled;
    * ROCm system: ``layernorm_ops.rms_norm`` from vLLM.

    Args:
        hidden_states: Tensor to normalize. On the PyTorch and ROCm paths it
            is updated in place with ``residual`` before normalization.
        residual: Optional tensor added to ``hidden_states`` before
            normalizing.

    Returns:
        The normalized tensor, with the same shape as the input.

    Raises:
        ValueError: If the last dimension is at most 8192 and the system is
            detected as neither CUDA nor ROCm, so no kernel is available.
    """
    if hidden_states.shape[-1] > 8192:
        if residual is not None:
            hidden_states += residual

        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(
            variance + self.variance_epsilon
        )

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
    elif IS_CUDA_SYSTEM:
        # faster post attention rms norm
        unwrap = False
        if len(hidden_states.shape) > 2:
            # The kernel is fed a 2D view; the original shape is restored below.
            unwrap = True
            shape = hidden_states.shape
            hidden_states = hidden_states.reshape(-1, shape[-1])

        # Only the normalized output is used; the residual output of the
        # kernel is not part of this method's return value.
        normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
            hidden_states,
            residual,
            self.weight,
            None,
            None,
            None,
            None,
            None,
            0.0,
            self.variance_epsilon,
            1.0,
            0,
            None,
            False,
            True,  # Activate RMSNorm
        )

        if unwrap:
            normed_hidden_states = normed_hidden_states.view(*shape)

        return normed_hidden_states
    elif IS_ROCM_SYSTEM:
        # We use the vLLM RMSNorm kernel, which can be compiled for ROCm,
        # instead of the Flash Attention one, which cannot.
        if residual is not None:
            hidden_states += residual

        unwrap = False
        if len(hidden_states.shape) > 2:
            unwrap = True
            shape = hidden_states.shape
            hidden_states = hidden_states.reshape(-1, shape[-1])

        # The kernel writes its result into a preallocated output tensor.
        out = torch.empty_like(hidden_states)
        layernorm_ops.rms_norm(
            out,
            hidden_states,
            self.weight.data,
            self.variance_epsilon,
        )

        # NOTE: there is no `res` value on this path (the kernel returns
        # nothing), so no residual fix-up is performed here.
        if unwrap:
            out = out.view(*shape)

        return out
    else:
        # Fail loudly instead of implicitly returning None.
        raise ValueError(
            "RMSNorm kernels are only available on CUDA or ROCm systems; "
            "neither was detected. Please check your install."
        )
# this was adapted from LlamaMLP # this was adapted from LlamaMLP

View File

@ -3,7 +3,7 @@ import torch
from loguru import logger from loguru import logger
from .import_utils import is_cuda_system, is_rocm_system from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.") raise ImportError("`USE_FLASH_ATTENTION` is false.")
@ -17,7 +17,8 @@ is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0 is_sm90 = major == 9 and minor == 0
HAS_FLASH_ATTN = False HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
try: try:
try: try:
import flash_attn_2_cuda import flash_attn_2_cuda
@ -32,7 +33,8 @@ try:
f"GPU with CUDA capability {major} {minor} is not supported for " f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2" "Flash Attention V2"
) )
HAS_FLASH_ATTN_V2 = True HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e: except ImportError as e:
try: try:
import flash_attn_cuda import flash_attn_cuda
@ -43,11 +45,11 @@ except ImportError as e:
"or install flash attention with `cd server && make install install-flash-attention`" "or install flash attention with `cd server && make install install-flash-attention`"
) from e ) from e
if is_cuda_system() and not (is_sm75 or is_sm8x or is_sm90): if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
raise ImportError( raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported" f"GPU with CUDA capability {major} {minor} is not supported"
) from e ) from e
elif is_rocm_system(): elif IS_ROCM_SYSTEM:
for idx in range(torch.cuda.device_count()): for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx): if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx):
raise ImportError( raise ImportError(
@ -69,7 +71,7 @@ def attention(
window_size_left=-1, window_size_left=-1,
): ):
# logger.info(f"HAS_FLASH_ATTN_V2 {HAS_FLASH_ATTN_V2}") # logger.info(f"HAS_FLASH_ATTN_V2 {HAS_FLASH_ATTN_V2}")
if HAS_FLASH_ATTN_V2: if HAS_FLASH_ATTN_V2_CUDA:
return flash_attn_2_cuda.varlen_fwd( return flash_attn_2_cuda.varlen_fwd(
q, q,
k, k,
@ -88,9 +90,25 @@ def attention(
False, False,
None, None,
) )
elif HAS_FLASH_ATTN_V2_ROCM:
# logger.info(f"HAS_FLASH_ATTN {HAS_FLASH_ATTN}") # RoCm flash API does not take the window_size_left and window_size_right arguments.
if HAS_FLASH_ATTN: return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
elif HAS_FLASH_ATTN:
if window_size_left != -1: if window_size_left != -1:
raise NotImplementedError( raise NotImplementedError(
"window_size_left is only available with flash attn v2" "window_size_left is only available with flash attn v2"
@ -135,8 +153,7 @@ def attention(
softmax_scale, softmax_scale,
False, False,
True, True,
False, # is_deterministic => rocm specific argument False,
False, # return_softmax
0, 0,
None, None,
) )

View File

@ -1,15 +1,16 @@
import subprocess import subprocess
def _command_succeeds(command):
    """Return True when ``command`` can be launched and exits with status 0.

    Args:
        command: Program name (or argument list) passed to
            ``subprocess.check_output``.

    Returns:
        ``True`` if the command ran and exited successfully, ``False`` if it
        could not be started or exited with a non-zero status.
    """
    try:
        subprocess.check_output(command)
    except (OSError, subprocess.SubprocessError):
        # OSError: the binary is missing or not executable.
        # SubprocessError: the tool ran but reported a failure.
        return False
    return True


# Platform flags, evaluated once when this module is first imported.
# A system counts as CUDA (resp. ROCm) when `nvidia-smi` (resp. `rocm-smi`)
# runs successfully.
IS_CUDA_SYSTEM = _command_succeeds("nvidia-smi")
IS_ROCM_SYSTEM = _command_succeeds("rocm-smi")

View File

@ -12,14 +12,13 @@ HAS_BITS_AND_BYTES = True
try: try:
import bitsandbytes as bnb import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params, Params4bit from bitsandbytes.nn import Int8Params, Params4bit
except ImportError: except ImportError:
HAS_BITS_AND_BYTES = False HAS_BITS_AND_BYTES = False
from accelerate import init_empty_weights from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
HAS_AWQ = True HAS_AWQ = True
try: try:
@ -509,50 +508,80 @@ class TensorParallelEmbedding(nn.Module):
try: try:
# import dropout_layer_norm if IS_CUDA_SYSTEM:
import dropout_layer_norm
else:
dropout_layer_norm = None
class FastLayerNorm(nn.LayerNorm):
    """LayerNorm that also handles the residual addition.

    Uses the fused ``dropout_layer_norm`` kernel when it is available, and
    the stock ``nn.LayerNorm`` implementation otherwise.
    """

    def forward(self, hidden_states, residual=None):
        """Add ``residual`` to ``hidden_states`` and layer-normalize the sum.

        Args:
            hidden_states: Tensor to normalize. On the PyTorch path it is
                updated in place with ``residual``.
            residual: Optional tensor added to ``hidden_states`` before
                normalization.

        Returns:
            A ``(normed_hidden_states, residual)`` tuple, where ``residual``
            is the pre-normalization tensor (``hidden_states`` itself when the
            fused kernel returns ``None`` for it).
        """
        # Use plain PyTorch for hidden sizes above 8192, on ROCm, and when the
        # fused kernel module was not imported (it is set to None on non-CUDA
        # systems), instead of failing on a None attribute lookup.
        use_torch_path = (
            hidden_states.shape[-1] > 8192
            or IS_ROCM_SYSTEM
            or dropout_layer_norm is None
        )
        if use_torch_path:
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            return super().forward(hidden_states), residual

        (
            normed_hidden_states,
            residual,
            *rest,
        ) = dropout_layer_norm.dropout_add_ln_fwd(
            hidden_states,
            residual,
            self.weight,
            self.bias,
            None,
            None,
            None,
            None,
            0.0,
            self.eps,
            1.0,
            0,
            None,
            False,
            False,
        )
        if residual is None:
            residual = hidden_states

        return normed_hidden_states, residual
except ImportError: except ImportError:
pass pass
try: try:
# from flash_attn.layers.rotary import RotaryEmbedding if IS_CUDA_SYSTEM:
# import rotary_emb from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
def rope_forward_cuda(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary position embedding to ``x`` with the ``rotary_emb`` kernel.

    The first ``2 * cos.shape[-1]`` features of the last dimension of ``x``
    are split into two consecutive halves that are handed to
    ``rotary_emb.apply_rotary``.

    Returns:
        ``x`` itself.
    """
    half_width = cos.shape[-1]
    first_half = x[..., :half_width]
    second_half = x[..., half_width : 2 * half_width]

    # The same slices are passed twice; presumably the second pair is the
    # kernel's output buffers, making this an in-place update of `x`
    # (TODO confirm against the rotary_emb extension).
    rotary_emb.apply_rotary(
        first_half, second_half, cos, sin, first_half, second_half, False
    )
    return x
elif IS_ROCM_SYSTEM:
# For ROCm, we fall back on a manual implementation because Flash Attention's RoPE kernel cannot be compiled for ROCm.
# We could use the vLLM RoPE kernel here (compatible with ROCm), but its API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3
def rope_forward_rocm(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary position embedding to ``x`` in place with plain PyTorch ops.

    The first ``2 * cos.shape[-1]`` features of the last dimension of ``x``
    are split into two consecutive halves ``(x1, x2)`` and overwritten with
    ``(x1 * cos - x2 * sin, x1 * sin + x2 * cos)``. The arithmetic is done in
    float32 and the result is cast back to the dtype of ``x``.

    Args:
        x: Tensor to rotate; modified in place.
        cos: Cosine table; its last dimension defines the size of each half.
        sin: Sine table, used with the same slices as ``cos``.

    Returns:
        ``x`` itself, after the in-place update.
    """
    rotary_dim = cos.shape[-1]

    dtype = x.dtype
    x_upcast = x.to(torch.float32)
    cos = cos.to(torch.float32)
    sin = sin.to(torch.float32)

    x1 = x_upcast[..., :rotary_dim]
    x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]

    # Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well.
    # Both halves are computed BEFORE anything is written back: when `x` is
    # already float32, `.to(torch.float32)` returns `x` itself, so `x1`/`x2`
    # are views of `x` and an early write would corrupt the second formula.
    rotated_x1 = x1 * cos - x2 * sin
    rotated_x2 = x1 * sin + x2 * cos

    x[..., :rotary_dim] = rotated_x1.to(dtype)
    x[..., rotary_dim : 2 * rotary_dim] = rotated_x2.to(dtype)
    return x
def _create_inv_freq(dim, base, device): def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / ( inv_freq = 1.0 / (
@ -690,21 +719,10 @@ try:
sin = torch.index_select(self._sin_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids)
return cos.unsqueeze(1), sin.unsqueeze(1) return cos.unsqueeze(1), sin.unsqueeze(1)
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): if IS_CUDA_SYSTEM:
rotary_dim = cos.shape[-1] PositionRotaryEmbedding.forward = rope_forward_cuda
elif IS_ROCM_SYSTEM:
dtype = x.dtype PositionRotaryEmbedding.forward = rope_forward_rocm
x_upcast = x.to(torch.float32)
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
x1 = x_upcast[..., :rotary_dim]
x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
# rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
# Flash Attention kernel casts everything to float, not sure why. In place op here
x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):

View File

@ -4,8 +4,6 @@ import torch
from vllm import cache_ops from vllm import cache_ops
from vllm import attention_ops from vllm import attention_ops
from loguru import logger
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
@ -56,7 +54,6 @@ def attention(
# sequences or heads is large, we use V1 since there is enough work # sequences or heads is large, we use V1 since there is enough work
# to parallelize. # to parallelize.
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
logger.info(f"paged attention use_v1 {use_v1}")
if use_v1: if use_v1:
attention_ops.paged_attention_v1( attention_ops.paged_attention_v1(
out, out,