clean rocm support

This commit is contained in:
Felix Marty 2023-11-07 14:56:11 +00:00
parent 52bdcf797d
commit ea8438a5a0
20 changed files with 378 additions and 263 deletions

View File

@ -107,7 +107,7 @@ WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2
RUN make build-flash-attention-v2
RUN make build-flash-attention-v2-cuda
# Build Transformers exllama kernels
FROM kernel-builder as exllama-kernels-builder
@ -145,7 +145,7 @@ WORKDIR /usr/src
COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm
RUN make build-vllm-cuda
# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
@ -200,7 +200,8 @@ COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
make gen-server && \
pip install -r requirements.txt && \
pip install -r requirements_common.txt && \
pip install -r requirements_cuda.txt && \
pip install ".[bnb, accelerate, quantize]" --no-cache-dir
# Install benchmarker
@ -215,7 +216,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
g++ \
&& rm -rf /var/lib/apt/lists/*
# AWS Sagemaker compatbile image
# AWS Sagemaker compatible image
FROM base as sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh

View File

@ -35,7 +35,7 @@ COPY router router
COPY launcher launcher
RUN cargo build --release
# Text Generation Inference base image
# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-20.04:5.7 as base
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
@ -48,14 +48,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
libssl-dev \
g++ \
wget \
# Needed to build VLLM.
# Needed to build VLLM & flash.
rocthrust-dev \
hipsparse-dev \
hipblas-dev && \
rm -rf /var/lib/apt/lists/*
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
RUN wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir .conda \
@ -70,66 +69,80 @@ ARG PYTORCH_VERSION='2.2.0.dev0'
ARG ROCM_VERSION='5.7'
ARG PYTHON_VERSION='3.11.5'
RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7
RUN pip install -U ninja
# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
FROM base AS kernel-builder
# Build vllm kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src
# Install VLLM.
RUN git clone https://github.com/fxmarty/vllm-public.git && cd vllm-public && git checkout --track origin/port-to-rocm
WORKDIR /usr/src/vllm-public
RUN pip install -r requirements.txt
RUN python setup.py install
COPY server/Makefile-vllm Makefile
# Install Flash Attention v1.
# Build specific version of vllm
RUN make build-vllm-rocm
# Build Flash Attention v2 kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src
RUN git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git && cd flash-attention && git submodule init && git submodule update && python setup.py install
# Not working for RoCm
# RUN cd flash-attention/csrc/rotary && python setup.py build && cd flash-attention/csrc/layer_norm && python setup.py build
COPY server/Makefile-flash-att-v2 Makefile
# COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention v2
RUN make build-flash-attention-v2-rocm
# Build specific version of flash attention
# RUN make build-flash-attention
# Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
# Build Transformers CUDA kernels
# NOTE: gpt-neox and bloom fused kernels
# FROM kernel-builder as custom-kernels-builder
# WORKDIR /usr/src
# COPY server/custom_kernels/ .
# Build specific version of transformers
# RUN python setup.py build
FROM base as base-copy
# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
# Copy build artifacts from flash attention builder
# COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
# COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && pip3 install -r requirements.txt
RUN cd server && pip install -r requirements_common.txt && \
pip install -r requirements_rocm.txt
RUN cd server && \
make gen-server && \
pip3 install ".[accelerate]" --no-cache-dir
make gen-server && pip install ".[accelerate]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcherg
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
# ENTRYPOINT ["text-generation-launcher"]
# CMD ["--json-output"]
# AWS Sagemaker compatible image
FROM base-copy as sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh
ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]

View File

@ -72,7 +72,9 @@ curl 127.0.0.1:8080/generate \
-H 'Content-Type: application/json'
```
**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 [to some extent](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0+rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@ -183,7 +185,7 @@ sudo apt-get install libssl-dev gcc -y
### CUDA Kernels
The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
The custom CUDA kernels are only tested on NVIDIA A100, AMD MI210 and AMD MI250. If you have any installation or runtime issues, you can remove
the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.
Be aware that the official Docker image has them enabled by default.

View File

@ -254,7 +254,7 @@ Options:
## DISABLE_CUSTOM_KERNELS
```shell
--disable-custom-kernels
For some models (like bloom), text-generation-inference implemented custom cuda kernels to speed up inference. Those kernels were only tested on A100. Use this flag to disable them if you're running on different hardware and encounter issues
For some models (like bloom), text-generation-inference implemented custom cuda kernels to speed up inference. Those kernels were only tested on Nvidia A100, AMD MI210 and AMD MI250. Use this flag to disable them if you're running on different hardware and encounter issues
[env: DISABLE_CUSTOM_KERNELS=]

View File

@ -15,6 +15,8 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) . We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
To use TGI on RoCm-enabled AMD GPUs (only MI210 and MI250 are tested), please use the image `ghcr.io/huggingface/text-generation-inference:1.1.1+rocm` instead. For details about the usage on RoCm, please refer to the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html).
</Tip>
Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.

View File

@ -39,9 +39,9 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
## Supported Hardware
TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other hardware, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
TGI also has experimental support of RoCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are missing from the RoCm version of TGI: quantization, flash [rotary embedding kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary), flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm).
TGI is also supported on the following AI hardware accelerators:
- *Habana first-gen Gaudi and Gaudi2:* check out this [example](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)

View File

@ -2,7 +2,7 @@ flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
flash-attention:
# Clone flash attention
pip install packaging
pip install -U packaging ninja --no-cache-dir
git clone https://github.com/HazyResearch/flash-attention.git
build-flash-attention: flash-attention

View File

@ -1,13 +1,26 @@
flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
build-flash-attention-v2-cuda: FLASH_ATTN_V2_COMMIT=02ac572f3ffc4f402e4183aaa6824b45859d3ed3
build-flash-attention-v2-cuda: FLASH_REPOSITORY=https://github.com/HazyResearch/flash-attention.git
build-flash-attention-v2-cuda: BRANCH=main
build-flash-attention-v2-cuda: PYTORCH_ROCM_ARCH=""
build-flash-attention-v2-cuda: build-flash-attention-v2
build-flash-attention-v2-rocm: FLASH_ATTN_V2_COMMIT=8736558c287ff2ef28b24878e42828c595ac3e69
build-flash-attention-v2-rocm: FLASH_REPOSITORY=https://github.com/fxmarty/flash-attention-rocm
build-flash-attention-v2-rocm: BRANCH=remove-offload-arch-native
build-flash-attention-v2-rocm: PYTORCH_ROCM_ARCH=gfx90a
build-flash-attention-v2-rocm: build-flash-attention-v2
flash-attention-v2:
# Clone flash attention
pip install packaging
git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
pip install -U packaging ninja --no-cache-dir
git clone --single-branch --branch $(BRANCH) $(FLASH_REPOSITORY) flash-attention-v2
build-flash-attention-v2: flash-attention-v2
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
cd flash-attention-v2 && python setup.py build
cd flash-attention-v2 && git fetch && git checkout $(FLASH_ATTN_V2_COMMIT)
cd flash-attention-v2 && git submodule update --init --recursive
cd flash-attention-v2 && PYTORCH_ROCM_ARCH=$(PYTORCH_ROCM_ARCH) python setup.py build
install-flash-attention-v2: build-flash-attention-v2
cd flash-attention-v2 && python setup.py install
cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install

View File

@ -1,11 +1,20 @@
vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
build-vllm-cuda: BRANCH=main
build-vllm-cuda: build-vllm
build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
build-vllm-rocm: VLLM_COMMIT=65f4a79621b4d992cf97f6b84598804eb4ca87b6
build-vllm-rocm: BRANCH=port-to-rocm
build-vllm-rocm: build-vllm
vllm:
# Clone vllm
git clone https://github.com/vllm-project/vllm.git
pip install -U ninja --no-cache-dir
git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
build-vllm: vllm
cd vllm && git fetch && git checkout $(vllm_commit)
cd vllm && git fetch && git checkout $(VLLM_COMMIT)
cd vllm && python setup.py build
install-vllm: build-vllm

View File

@ -1,5 +1,10 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
extra_compile_args = ["-std=c++17"]
if not torch.version.hip:
extra_compile_args.append("-arch=compute_80")
setup(
name="custom_kernels",
@ -7,12 +12,12 @@ setup(
CUDAExtension(
name="custom_kernels.fused_bloom_attention_cuda",
sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
extra_compile_args=["-arch=compute_80", "-std=c++17"],
extra_compile_args=extra_compile_args,
),
CUDAExtension(
name="custom_kernels.fused_attention_cuda",
sources=["custom_kernels/fused_attention_cuda.cu"],
extra_compile_args=["-arch=compute_80", "-std=c++17"],
extra_compile_args=extra_compile_args,
),
],
cmdclass={"build_ext": BuildExtension},

View File

@ -4,8 +4,6 @@ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
# bitsandbytes is broken on RoCm systems
# bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
@ -63,8 +61,6 @@ six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
texttable==1.6.7 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
# We use nightly
torch>2.1.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.33.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -0,0 +1,3 @@
torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
# bitsandbytes can not compile on RoCm systems, hence only installed for Nvidia GPUs
bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -0,0 +1,2 @@
# We use nightly
torch>2.1.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -28,9 +28,6 @@ from typing import Optional, List, Tuple
from loguru import logger
# Flash attention imports
# import dropout_layer_norm
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@ -40,7 +37,11 @@ from text_generation_server.utils.layers import (
TensorParallelHead,
get_linear,
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
torch.set_printoptions(threshold=10000000, sci_mode=True)
@ -125,8 +126,31 @@ class LlamaRMSNorm(nn.Module):
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
else:
# We use VLLM kernels that are compiled for RoCm instead of Flash Attention ones that can't be used.
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
@ -140,30 +164,6 @@ class LlamaRMSNorm(nn.Module):
)
return out, residual
# else:
# # faster post attention rms norm
# normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
# hidden_states,
# residual,
# self.weight,
# None,
# None,
# None,
# None,
# None,
# 0.0,
# self.variance_epsilon,
# 1.0,
# 0,
# None,
# False,
# True, # Activate RMSNorm
# )
# if res is None:
# res = hidden_states
# return normed_hidden_states, res
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
@ -282,11 +282,6 @@ class FlashLlamaAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
# logger.info(f"query before rotary {query[:10, ..., :8]}")
# logger.info(f"cos before rotary {cos[:10]}")
# logger.info(f"sin before rotary {sin[:10]}")
# TODO: maybe we can use VLLM rotary here, which would require position_ids? Probably too big of a change...
# Flash Attention kernel may be usable since it is Triton-based
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
@ -297,9 +292,6 @@ class FlashLlamaAttention(torch.nn.Module):
# output tensor
attn_output = torch.empty_like(query)
# logger.info(f"query {query.shape}")
# logger.info(f"query piece {query[:10, ..., :8]}")
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
@ -326,9 +318,6 @@ class FlashLlamaAttention(torch.nn.Module):
max_s,
)
# logger.info(f"attn_output {attn_output.shape}")
# logger.info(f"attn_output piece {attn_output[:10, ..., :8]}")
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -26,11 +26,8 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
# Flash attention imports
# import dropout_layer_norm
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_ROCM
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -39,8 +36,14 @@ from text_generation_server.utils.layers import (
TensorParallelHead,
get_linear,
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if not HAS_FLASH_ATTN_V2:
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
if not HAS_FLASH_ATTN_V2_ROCM and not HAS_FLASH_ATTN_V2_ROCM:
raise ImportError("Mistral model requires flash attn v2")
@ -110,7 +113,7 @@ class MistralRMSNorm(nn.Module):
self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
# if hidden_states.shape[-1] > 8192:
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
@ -126,29 +129,43 @@ class MistralRMSNorm(nn.Module):
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
# else:
# # faster post attention rms norm
# normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
# hidden_states,
# residual,
# self.weight,
# None,
# None,
# None,
# None,
# None,
# 0.0,
# self.variance_epsilon,
# 1.0,
# 0,
# None,
# False,
# True, # Activate RMSNorm
# )
# if res is None:
# res = hidden_states
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
# return normed_hidden_states, res
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
def load_attention(config, prefix, weights):

View File

@ -55,8 +55,12 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding,
FastLinear,
)
# import dropout_layer_norm
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
@ -354,7 +358,7 @@ class IdeficsRMSNorm(nn.Module):
self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
# if hidden_states.shape[-1] > 8192:
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
@ -370,38 +374,64 @@ class IdeficsRMSNorm(nn.Module):
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
# else:
# # faster post attention rms norm
# unwrap = False
# if len(hidden_states.shape) > 2:
# unwrap = True
# shape = hidden_states.shape
# hidden_states = hidden_states.reshape(-1, shape[-1])
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
unwrap = False
if len(hidden_states.shape) > 2:
unwrap = True
shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, shape[-1])
# normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
# hidden_states,
# residual,
# self.weight,
# None,
# None,
# None,
# None,
# None,
# 0.0,
# self.variance_epsilon,
# 1.0,
# 0,
# None,
# False,
# True, # Activate RMSNorm
# )
# if res is None:
# res = hidden_states
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
# if unwrap:
# normed_hidden_states = normed_hidden_states.view(*shape)
if unwrap:
normed_hidden_states = normed_hidden_states.view(*shape)
# return normed_hidden_states
return normed_hidden_states
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
unwrap = False
if len(hidden_states.shape) > 2:
unwrap = True
shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, shape[-1])
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
if res is None:
res = hidden_states
if unwrap:
out = out.view(*shape)
return out
# this was adapted from LlamaMLP

View File

@ -3,7 +3,7 @@ import torch
from loguru import logger
from .import_utils import is_cuda_system, is_rocm_system
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
@ -17,7 +17,8 @@ is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
try:
try:
import flash_attn_2_cuda
@ -32,7 +33,8 @@ try:
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2 = True
HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e:
try:
import flash_attn_cuda
@ -43,11 +45,11 @@ except ImportError as e:
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
if is_cuda_system() and not (is_sm75 or is_sm8x or is_sm90):
if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
elif is_rocm_system():
elif IS_ROCM_SYSTEM:
for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx):
raise ImportError(
@ -69,7 +71,7 @@ def attention(
window_size_left=-1,
):
# logger.info(f"HAS_FLASH_ATTN_V2 {HAS_FLASH_ATTN_V2}")
if HAS_FLASH_ATTN_V2:
if HAS_FLASH_ATTN_V2_CUDA:
return flash_attn_2_cuda.varlen_fwd(
q,
k,
@ -88,9 +90,25 @@ def attention(
False,
None,
)
# logger.info(f"HAS_FLASH_ATTN {HAS_FLASH_ATTN}")
if HAS_FLASH_ATTN:
elif HAS_FLASH_ATTN_V2_ROCM:
# RoCm flash API does not take the window_size_left and window_size_right arguments.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
elif HAS_FLASH_ATTN:
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
@ -135,8 +153,7 @@ def attention(
softmax_scale,
False,
True,
False, # is_deterministic => rocm specific argument
False, # return_softmax
False,
0,
None,
)

View File

@ -1,15 +1,16 @@
import subprocess
def is_cuda_system():
IS_CUDA_SYSTEM = False
IS_ROCM_SYSTEM = False
try:
subprocess.check_output("nvidia-smi")
return True
IS_CUDA_SYSTEM = True
except Exception:
return False
IS_CUDA_SYSTEM = False
def is_rocm_system():
try:
subprocess.check_output("rocm-smi")
return True
IS_ROCM_SYSTEM = True
except Exception:
return False
IS_ROCM_SYSTEM = False

View File

@ -12,14 +12,13 @@ HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params, Params4bit
except ImportError:
HAS_BITS_AND_BYTES = False
from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
HAS_AWQ = True
try:
@ -509,50 +508,80 @@ class TensorParallelEmbedding(nn.Module):
try:
# import dropout_layer_norm
if IS_CUDA_SYSTEM:
import dropout_layer_norm
else:
dropout_layer_norm = None
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
# if hidden_states.shape[-1] > 8192:
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
# Mistral does not use RMSNorm.
if residual is not None:
hidden_states += residual
residual = hidden_states
return super(FastLayerNorm, self).forward(hidden_states), residual
# else:
# (
# normed_hidden_states,
# residual,
# *rest,
# ) = dropout_layer_norm.dropout_add_ln_fwd(
# hidden_states,
# residual,
# self.weight,
# self.bias,
# None,
# None,
# None,
# None,
# 0.0,
# self.eps,
# 1.0,
# 0,
# None,
# False,
# False,
# )
# if residual is None:
# residual = hidden_states
# return normed_hidden_states, residual
else:
(
normed_hidden_states,
residual,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
self.bias,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
if residual is None:
residual = hidden_states
return normed_hidden_states, residual
except ImportError:
pass
try:
# from flash_attn.layers.rotary import RotaryEmbedding
# import rotary_emb
if IS_CUDA_SYSTEM:
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
def rope_forward_cuda(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
x1 = x[..., :rotary_dim]
x2 = x[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
return x
elif IS_ROCM_SYSTEM:
# For RoCm, we fall back on a manual implementation given that Flash Attention's ROPE kernel can not be compiled for RoCm.
# We could use VLLM ROPE kernel here (compatible with RoCm), but the API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3
def rope_forward_rocm(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
dtype = x.dtype
x_upcast = x.to(torch.float32)
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
x1 = x_upcast[..., :rotary_dim]
x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
# Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well.
x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
return x
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
@ -690,21 +719,10 @@ try:
sin = torch.index_select(self._sin_cached, 0, position_ids)
return cos.unsqueeze(1), sin.unsqueeze(1)
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
dtype = x.dtype
x_upcast = x.to(torch.float32)
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
x1 = x_upcast[..., :rotary_dim]
x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
# rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
# Flash Attention kernel casts everything to float, not sure why. In place op here
x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
if IS_CUDA_SYSTEM:
PositionRotaryEmbedding.forward = rope_forward_cuda
elif IS_ROCM_SYSTEM:
PositionRotaryEmbedding.forward = rope_forward_rocm
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):

View File

@ -4,8 +4,6 @@ import torch
from vllm import cache_ops
from vllm import attention_ops
from loguru import logger
_PARTITION_SIZE = 512
@ -56,7 +54,6 @@ def attention(
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
logger.info(f"paged attention use_v1 {use_v1}")
if use_v1:
attention_ops.paged_attention_v1(
out,