clean rocm support

commit ea8438a5a0
parent 52bdcf797d
@@ -107,7 +107,7 @@ WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2
RUN make build-flash-attention-v2-cuda

# Build Transformers exllama kernels
FROM kernel-builder as exllama-kernels-builder

@@ -145,7 +145,7 @@ WORKDIR /usr/src
COPY server/Makefile-vllm Makefile

# Build specific version of vllm
RUN make build-vllm
RUN make build-vllm-cuda

# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base

@@ -200,7 +200,8 @@ COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements.txt && \
    pip install -r requirements_common.txt && \
    pip install -r requirements_cuda.txt && \
    pip install ".[bnb, accelerate, quantize]" --no-cache-dir

# Install benchmarker

@@ -215,7 +216,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
    g++ \
    && rm -rf /var/lib/apt/lists/*

# AWS Sagemaker compatbile image
# AWS Sagemaker compatible image
FROM base as sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
@@ -35,7 +35,7 @@ COPY router router
COPY launcher launcher
RUN cargo build --release

# Text Generation Inference base image
# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-20.04:5.7 as base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \

@@ -48,14 +48,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
    libssl-dev \
    g++ \
    wget \
    # Needed to build VLLM.
    # Needed to build VLLM & flash.
    rocthrust-dev \
    hipsparse-dev \
    hipblas-dev && \
    rm -rf /var/lib/apt/lists/*

# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.

RUN wget \
    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir .conda \
@@ -70,66 +69,80 @@ ARG PYTORCH_VERSION='2.2.0.dev0'
ARG ROCM_VERSION='5.7'
ARG PYTHON_VERSION='3.11.5'

RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7
RUN pip install -U ninja
# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7

FROM base AS kernel-builder

# Build vllm kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src

# Install VLLM.
RUN git clone https://github.com/fxmarty/vllm-public.git && cd vllm-public && git checkout --track origin/port-to-rocm
WORKDIR /usr/src/vllm-public
RUN pip install -r requirements.txt
RUN python setup.py install
COPY server/Makefile-vllm Makefile

# Install Flash Attention v1.
# Build specific version of vllm
RUN make build-vllm-rocm

# Build Flash Attention v2 kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src
RUN git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git && cd flash-attention && git submodule init && git submodule update && python setup.py install

# Not working for RoCm
# RUN cd flash-attention/csrc/rotary && python setup.py build && cd flash-attention/csrc/layer_norm && python setup.py build
COPY server/Makefile-flash-att-v2 Makefile

# COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention v2
RUN make build-flash-attention-v2-rocm

# Build specific version of flash attention
# RUN make build-flash-attention
# Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build

# Build Transformers CUDA kernels
# NOTE: gpt-neox and bloom fused kernels

# FROM kernel-builder as custom-kernels-builder
# WORKDIR /usr/src
# COPY server/custom_kernels/ .
# Build specific version of transformers
# RUN python setup.py build
FROM base as base-copy

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

# Copy build artifacts from flash attention builder
# COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages

# Copy build artifacts from custom kernels builder
# COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /root/miniconda3/lib/python3.11/site-packages

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && pip3 install -r requirements.txt
RUN cd server && pip install -r requirements_common.txt && \
    pip install -r requirements_rocm.txt

RUN cd server && \
    make gen-server && \
    pip3 install ".[accelerate]" --no-cache-dir
    make gen-server && pip install ".[accelerate]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcherg
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

# ENTRYPOINT ["text-generation-launcher"]
# CMD ["--json-output"]
# AWS Sagemaker compatible image
FROM base-copy as sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM base-copy

ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
@@ -72,7 +72,9 @@ curl 127.0.0.1:8080/generate \
    -H 'Content-Type: application/json'
```

**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.

**Note:** TGI supports AMD Instinct MI210 and MI250 [to some extent](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0+rocm --model-id $model` instead of the command above.

To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```

@@ -183,7 +185,7 @@ sudo apt-get install libssl-dev gcc -y

### CUDA Kernels

The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
The custom CUDA kernels are only tested on NVIDIA A100, AMD MI210 and AMD MI250. If you have any installation or runtime issues, you can remove
the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.

Be aware that the official Docker image has them enabled by default.
@@ -254,7 +254,7 @@ Options:
## DISABLE_CUSTOM_KERNELS
```shell
      --disable-custom-kernels
          For some models (like bloom), text-generation-inference implemented custom cuda kernels to speed up inference. Those kernels were only tested on A100. Use this flag to disable them if you're running on different hardware and encounter issues
          For some models (like bloom), text-generation-inference implemented custom cuda kernels to speed up inference. Those kernels were only tested on Nvidia A100, AMD MI210 and AMD MI250. Use this flag to disable them if you're running on different hardware and encounter issues

          [env: DISABLE_CUSTOM_KERNELS=]
@@ -15,6 +15,8 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf

To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) . We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.

To use TGI on RoCm-enabled AMD GPUs (only MI210 and MI250 are tested), please use the image `ghcr.io/huggingface/text-generation-inference:1.1.1+rocm` instead. For details about the usage on RoCm, please refer to the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html).

</Tip>

Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
@@ -39,9 +39,9 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>

## Supported Hardware

TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other hardware, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.

TGI also has experimental support of RoCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are missing from the RoCm version of TGI: quantization, flash [rotary embedding kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary), flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm).

TGI is also supported on the following AI hardware accelerators:
- *Habana first-gen Gaudi and Gaudi2:* check out this [example](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
@@ -2,7 +2,7 @@ flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec

flash-attention:
	# Clone flash attention
	pip install packaging
	pip install -U packaging ninja --no-cache-dir
	git clone https://github.com/HazyResearch/flash-attention.git

build-flash-attention: flash-attention
@@ -1,13 +1,26 @@
flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3

build-flash-attention-v2-cuda: FLASH_ATTN_V2_COMMIT=02ac572f3ffc4f402e4183aaa6824b45859d3ed3
build-flash-attention-v2-cuda: FLASH_REPOSITORY=https://github.com/HazyResearch/flash-attention.git
build-flash-attention-v2-cuda: BRANCH=main
build-flash-attention-v2-cuda: PYTORCH_ROCM_ARCH=""
build-flash-attention-v2-cuda: build-flash-attention-v2

build-flash-attention-v2-rocm: FLASH_ATTN_V2_COMMIT=8736558c287ff2ef28b24878e42828c595ac3e69
build-flash-attention-v2-rocm: FLASH_REPOSITORY=https://github.com/fxmarty/flash-attention-rocm
build-flash-attention-v2-rocm: BRANCH=remove-offload-arch-native
build-flash-attention-v2-rocm: PYTORCH_ROCM_ARCH=gfx90a
build-flash-attention-v2-rocm: build-flash-attention-v2

flash-attention-v2:
	# Clone flash attention
	pip install packaging
	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
	pip install -U packaging ninja --no-cache-dir
	git clone --single-branch --branch $(BRANCH) $(FLASH_REPOSITORY) flash-attention-v2

build-flash-attention-v2: flash-attention-v2
	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
	cd flash-attention-v2 && python setup.py build
	cd flash-attention-v2 && git fetch && git checkout $(FLASH_ATTN_V2_COMMIT)
	cd flash-attention-v2 && git submodule update --init --recursive
	cd flash-attention-v2 && PYTORCH_ROCM_ARCH=$(PYTORCH_ROCM_ARCH) python setup.py build

install-flash-attention-v2: build-flash-attention-v2
	cd flash-attention-v2 && python setup.py install
	cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
@@ -1,11 +1,20 @@
vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
build-vllm-cuda: BRANCH=main
build-vllm-cuda: build-vllm

build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
build-vllm-rocm: VLLM_COMMIT=65f4a79621b4d992cf97f6b84598804eb4ca87b6
build-vllm-rocm: BRANCH=port-to-rocm
build-vllm-rocm: build-vllm

vllm:
	# Clone vllm
	git clone https://github.com/vllm-project/vllm.git
	pip install -U ninja --no-cache-dir
	git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm

build-vllm: vllm
	cd vllm && git fetch && git checkout $(vllm_commit)
	cd vllm && git fetch && git checkout $(VLLM_COMMIT)
	cd vllm && python setup.py build

install-vllm: build-vllm
@@ -1,5 +1,10 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch

extra_compile_args = ["-std=c++17"]
if not torch.version.hip:
    extra_compile_args.append("-arch=compute_80")

setup(
    name="custom_kernels",

@@ -7,12 +12,12 @@ setup(
        CUDAExtension(
            name="custom_kernels.fused_bloom_attention_cuda",
            sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
            extra_compile_args=["-arch=compute_80", "-std=c++17"],
            extra_compile_args=extra_compile_args,
        ),
        CUDAExtension(
            name="custom_kernels.fused_attention_cuda",
            sources=["custom_kernels/fused_attention_cuda.cu"],
            extra_compile_args=["-arch=compute_80", "-std=c++17"],
            extra_compile_args=extra_compile_args,
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
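The change above drops the nvcc-only `-arch=compute_80` flag when PyTorch was built for ROCm. A minimal, illustrative sketch of the condition it keys on: `torch.version.hip` is a version string on ROCm builds of PyTorch and `None` on CUDA builds (the print line is only for inspection and is not part of the commit).

```python
import torch

# Mirror of the conditional added to setup.py: hipcc rejects the nvcc-specific
# -arch=compute_80 flag, so it is only appended when this is not a ROCm build.
extra_compile_args = ["-std=c++17"]
if not torch.version.hip:
    extra_compile_args.append("-arch=compute_80")

# Inspect which backend this torch build targets (one of the two is None).
print("hip:", torch.version.hip, "cuda:", torch.version.cuda, "->", extra_compile_args)
```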
@@ -4,8 +4,6 @@ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
# bitsandbytes is broken on RoCm systems
# bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"

@@ -63,8 +61,6 @@ six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
texttable==1.6.7 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
# We use nightly
torch>2.1.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.33.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
server/requirements_cuda.txt (new file)
@@ -0,0 +1,3 @@
torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
# bitsandbytes can not compile on RoCm systems, hence only installed for Nvidia GPUs
bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13"

server/requirements_rocm.txt (new file)
@@ -0,0 +1,2 @@
# We use nightly
torch>2.1.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -28,9 +28,6 @@ from typing import Optional, List, Tuple

from loguru import logger

# Flash attention imports
# import dropout_layer_norm

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,

@@ -40,7 +37,11 @@ from text_generation_server.utils.layers import (
    TensorParallelHead,
    get_linear,
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

if IS_CUDA_SYSTEM:
    import dropout_layer_norm
elif IS_ROCM_SYSTEM:
    from vllm import layernorm_ops

torch.set_printoptions(threshold=10000000, sci_mode=True)
@@ -125,8 +126,31 @@ class LlamaRMSNorm(nn.Module):
            hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states, residual
        else:
            # We use VLLM kernels that are compiled for RoCm instead of Flash Attention ones that can't be used.
        elif IS_CUDA_SYSTEM:
            # faster post attention rms norm
            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            if res is None:
                res = hidden_states

            return normed_hidden_states, res
        elif IS_ROCM_SYSTEM:
            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

@@ -140,30 +164,6 @@ class LlamaRMSNorm(nn.Module):
            )
            return out, residual

        # else:
        #     # faster post attention rms norm
        #     normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
        #         hidden_states,
        #         residual,
        #         self.weight,
        #         None,
        #         None,
        #         None,
        #         None,
        #         None,
        #         0.0,
        #         self.variance_epsilon,
        #         1.0,
        #         0,
        #         None,
        #         False,
        #         True,  # Activate RMSNorm
        #     )
        #     if res is None:
        #         res = hidden_states

        #     return normed_hidden_states, res
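Pulled out of the hunk above, the new RoCm branch amounts to the following sketch (illustrative only; it assumes the `layernorm_ops` extension from the vLLM build installed in the Dockerfile, with the `rms_norm(out, input, weight, eps)` call order used in the diff):

```python
import torch
from vllm import layernorm_ops  # ROCm-compatible fused kernel from the vLLM build above


def rms_norm_rocm(hidden_states, weight, eps, residual=None):
    # Fold the residual in first, as the elif IS_ROCM_SYSTEM branch does,
    # then run the fused RMSNorm kernel into a preallocated output tensor.
    if residual is not None:
        hidden_states += residual
    residual = hidden_states

    out = torch.empty_like(hidden_states)
    layernorm_ops.rms_norm(out, hidden_states, weight.data, eps)
    return out, residual
```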
def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:

@@ -282,11 +282,6 @@ class FlashLlamaAttention(torch.nn.Module):
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # logger.info(f"query before rotary {query[:10, ..., :8]}")
        # logger.info(f"cos before rotary {cos[:10]}")
        # logger.info(f"sin before rotary {sin[:10]}")
        # TODO: maybe we can use VLLM rotary here, which would require position_ids? Probably too big of a change...
        # Flash Attention kernel may be usable since it is Triton-based
        self.rotary_emb(query, cos, sin)
        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

@@ -297,9 +292,6 @@ class FlashLlamaAttention(torch.nn.Module):
        # output tensor
        attn_output = torch.empty_like(query)

        # logger.info(f"query {query.shape}")
        # logger.info(f"query piece {query[:10, ..., :8]}")
        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention

@@ -326,9 +318,6 @@ class FlashLlamaAttention(torch.nn.Module):
                max_s,
            )

        # logger.info(f"attn_output {attn_output.shape}")
        # logger.info(f"attn_output piece {attn_output[:10, ..., :8]}")

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -26,11 +26,8 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

# Flash attention imports
# import dropout_layer_norm

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2_CUDA, HAS_FLASH_ATTN_V2_ROCM
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,

@@ -39,8 +36,14 @@ from text_generation_server.utils.layers import (
    TensorParallelHead,
    get_linear,
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

if not HAS_FLASH_ATTN_V2:
if IS_CUDA_SYSTEM:
    import dropout_layer_norm
elif IS_ROCM_SYSTEM:
    from vllm import layernorm_ops

if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
    raise ImportError("Mistral model requires flash attn v2")
@@ -110,7 +113,7 @@ class MistralRMSNorm(nn.Module):
        self.variance_epsilon = eps

    def forward(self, hidden_states, residual=None):
        # if hidden_states.shape[-1] > 8192:
        if hidden_states.shape[-1] > 8192:
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

@@ -126,29 +129,43 @@ class MistralRMSNorm(nn.Module):
            hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states, residual
        # else:
        #     # faster post attention rms norm
        #     normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
        #         hidden_states,
        #         residual,
        #         self.weight,
        #         None,
        #         None,
        #         None,
        #         None,
        #         None,
        #         0.0,
        #         self.variance_epsilon,
        #         1.0,
        #         0,
        #         None,
        #         False,
        #         True,  # Activate RMSNorm
        #     )
        #     if res is None:
        #         res = hidden_states
        elif IS_CUDA_SYSTEM:
            # faster post attention rms norm
            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            if res is None:
                res = hidden_states

        #     return normed_hidden_states, res
            return normed_hidden_states, res
        elif IS_ROCM_SYSTEM:
            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            out = torch.empty_like(hidden_states)
            layernorm_ops.rms_norm(
                out,
                hidden_states,
                self.weight.data,
                self.variance_epsilon,
            )
            return out, residual


def load_attention(config, prefix, weights):
@@ -55,8 +55,12 @@ from text_generation_server.utils.layers import (
    PositionRotaryEmbedding,
    FastLinear,
)
# import dropout_layer_norm
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

if IS_CUDA_SYSTEM:
    import dropout_layer_norm
elif IS_ROCM_SYSTEM:
    from vllm import layernorm_ops

@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):

@@ -354,7 +358,7 @@ class IdeficsRMSNorm(nn.Module):
        self.variance_epsilon = eps

    def forward(self, hidden_states, residual=None):
        # if hidden_states.shape[-1] > 8192:
        if hidden_states.shape[-1] > 8192:
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

@@ -370,38 +374,64 @@ class IdeficsRMSNorm(nn.Module):
            hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states
        # else:
        #     # faster post attention rms norm
        #     unwrap = False
        #     if len(hidden_states.shape) > 2:
        #         unwrap = True
        #         shape = hidden_states.shape
        #         hidden_states = hidden_states.reshape(-1, shape[-1])
        elif IS_CUDA_SYSTEM:
            # faster post attention rms norm
            unwrap = False
            if len(hidden_states.shape) > 2:
                unwrap = True
                shape = hidden_states.shape
                hidden_states = hidden_states.reshape(-1, shape[-1])

        #     normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
        #         hidden_states,
        #         residual,
        #         self.weight,
        #         None,
        #         None,
        #         None,
        #         None,
        #         None,
        #         0.0,
        #         self.variance_epsilon,
        #         1.0,
        #         0,
        #         None,
        #         False,
        #         True,  # Activate RMSNorm
        #     )
        #     if res is None:
        #         res = hidden_states
            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            if res is None:
                res = hidden_states

        #     if unwrap:
        #         normed_hidden_states = normed_hidden_states.view(*shape)
            if unwrap:
                normed_hidden_states = normed_hidden_states.view(*shape)

        # return normed_hidden_states
            return normed_hidden_states
        elif IS_ROCM_SYSTEM:
            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            unwrap = False
            if len(hidden_states.shape) > 2:
                unwrap = True
                shape = hidden_states.shape
                hidden_states = hidden_states.reshape(-1, shape[-1])

            out = torch.empty_like(hidden_states)
            layernorm_ops.rms_norm(
                out,
                hidden_states,
                self.weight.data,
                self.variance_epsilon,
            )
            if res is None:
                res = hidden_states

            if unwrap:
                out = out.view(*shape)

            return out


# this was adapted from LlamaMLP
@@ -3,7 +3,7 @@ import torch

from loguru import logger

from .import_utils import is_cuda_system, is_rocm_system
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

@@ -17,7 +17,8 @@ is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
try:
    try:
        import flash_attn_2_cuda

@@ -32,7 +33,8 @@ try:
            f"GPU with CUDA capability {major} {minor} is not supported for "
            "Flash Attention V2"
        )
    HAS_FLASH_ATTN_V2 = True
    HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
    HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e:
    try:
        import flash_attn_cuda

@@ -43,11 +45,11 @@ except ImportError as e:
            "or install flash attention with `cd server && make install install-flash-attention`"
        ) from e

    if is_cuda_system() and not (is_sm75 or is_sm8x or is_sm90):
    if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported"
        ) from e
    elif is_rocm_system():
    elif IS_ROCM_SYSTEM:
        for idx in range(torch.cuda.device_count()):
            if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx):
                raise ImportError(

@@ -69,7 +71,7 @@ def attention(
    window_size_left=-1,
):
    # logger.info(f"HAS_FLASH_ATTN_V2 {HAS_FLASH_ATTN_V2}")
    if HAS_FLASH_ATTN_V2:
    if HAS_FLASH_ATTN_V2_CUDA:
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,

@@ -88,9 +90,25 @@ def attention(
            False,
            None,
        )

    # logger.info(f"HAS_FLASH_ATTN {HAS_FLASH_ATTN}")
    if HAS_FLASH_ATTN:
    elif HAS_FLASH_ATTN_V2_ROCM:
        # RoCm flash API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )
    elif HAS_FLASH_ATTN:
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"

@@ -135,8 +153,7 @@ def attention(
            softmax_scale,
            False,
            True,
            False, # is_deterministic => rocm specific argument
            False, # return_softmax
            False,
            0,
            None,
        )
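As a usage note on the RoCm gate added above: flash attention v2 on RoCm is only enabled when every visible device is an MI210 or MI250. A small, illustrative self-check (it assumes a ROCm build of PyTorch, where the AMD GPUs are exposed through the `torch.cuda` API):

```python
import torch

# Reproduces the device check from the hunk above: any device that is neither
# an MI210 nor an MI250 makes the import raise, disabling flash attention v2.
supported = all(
    any(name in torch.cuda.get_device_name(idx) for name in ("MI210", "MI250"))
    for idx in range(torch.cuda.device_count())
)
print("flash attention v2 (RoCm) usable on this machine:", supported)
```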
@@ -1,15 +1,16 @@
import subprocess

def is_cuda_system():
IS_CUDA_SYSTEM = False
IS_ROCM_SYSTEM = False

try:
    subprocess.check_output("nvidia-smi")
    return True
    IS_CUDA_SYSTEM = True
except Exception:
    return False
    IS_CUDA_SYSTEM = False

def is_rocm_system():
try:
    subprocess.check_output("rocm-smi")
    return True
    IS_ROCM_SYSTEM = True
except Exception:
    return False
    IS_ROCM_SYSTEM = False
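Putting the interleaved old and new lines together, the module after this change reads roughly as follows (reconstructed from the hunk above: detection now runs once at import time and is exposed as constants rather than the previous `is_cuda_system()`/`is_rocm_system()` helpers):

```python
import subprocess

IS_CUDA_SYSTEM = False
IS_ROCM_SYSTEM = False

# nvidia-smi only exists on NVIDIA systems and rocm-smi only on ROCm systems,
# so a successful call identifies the platform once at import time.
try:
    subprocess.check_output("nvidia-smi")
    IS_CUDA_SYSTEM = True
except Exception:
    IS_CUDA_SYSTEM = False

try:
    subprocess.check_output("rocm-smi")
    IS_ROCM_SYSTEM = True
except Exception:
    IS_ROCM_SYSTEM = False
```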
@@ -12,14 +12,13 @@ HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params, Params4bit

except ImportError:
    HAS_BITS_AND_BYTES = False

from accelerate import init_empty_weights

from text_generation_server.utils.gptq.quant_linear import QuantLinear

from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

HAS_AWQ = True
try:

@@ -509,50 +508,80 @@ class TensorParallelEmbedding(nn.Module):


try:
    # import dropout_layer_norm
    if IS_CUDA_SYSTEM:
        import dropout_layer_norm
    else:
        dropout_layer_norm = None

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            # if hidden_states.shape[-1] > 8192:
            if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
                # Mistral does not use RMSNorm.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            # else:
            #     (
            #         normed_hidden_states,
            #         residual,
            #         *rest,
            #     ) = dropout_layer_norm.dropout_add_ln_fwd(
            #         hidden_states,
            #         residual,
            #         self.weight,
            #         self.bias,
            #         None,
            #         None,
            #         None,
            #         None,
            #         0.0,
            #         self.eps,
            #         1.0,
            #         0,
            #         None,
            #         False,
            #         False,
            #     )
            #     if residual is None:
            #         residual = hidden_states

            #     return normed_hidden_states, residual
            else:
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual
except ImportError:
    pass
try:
    # from flash_attn.layers.rotary import RotaryEmbedding
    # import rotary_emb
    if IS_CUDA_SYSTEM:
        from flash_attn.layers.rotary import RotaryEmbedding
        import rotary_emb

        def rope_forward_cuda(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
            rotary_dim = cos.shape[-1]
            x1 = x[..., :rotary_dim]
            x2 = x[..., rotary_dim : 2 * rotary_dim]

            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
            return x
    elif IS_ROCM_SYSTEM:
        # For RoCm, we fall back on a manual implementation given that Flash Attention's ROPE kernel can not be compiled for RoCm.
        # We could use VLLM ROPE kernel here (compatible with RoCm), but the API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3
        def rope_forward_rocm(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
            rotary_dim = cos.shape[-1]

            dtype = x.dtype
            x_upcast = x.to(torch.float32)
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            x1 = x_upcast[..., :rotary_dim]
            x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]

            # Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well.
            x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
            x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
            return x

    def _create_inv_freq(dim, base, device):
        inv_freq = 1.0 / (

@@ -690,21 +719,10 @@ try:
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            return cos.unsqueeze(1), sin.unsqueeze(1)

        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
            rotary_dim = cos.shape[-1]

            dtype = x.dtype
            x_upcast = x.to(torch.float32)
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            x1 = x_upcast[..., :rotary_dim]
            x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]

            # rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
            # Flash Attention kernel casts everything to float, not sure why. In place op here
            x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
            x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
    if IS_CUDA_SYSTEM:
        PositionRotaryEmbedding.forward = rope_forward_cuda
    elif IS_ROCM_SYSTEM:
        PositionRotaryEmbedding.forward = rope_forward_rocm

    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
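A tiny, self-contained sanity check of the arithmetic behind `rope_forward_rocm` (shapes here are hypothetical; the real method operates in place on the query/key tensors and upcasts to float32 as the diff shows). Rotating by an angle and then by its negation must return the original values:

```python
import torch


def apply_rotary_manual(x, cos, sin):
    # Same rotation as rope_forward_rocm, written out-of-place for the check.
    rotary_dim = cos.shape[-1]
    x1 = x[..., :rotary_dim]
    x2 = x[..., rotary_dim : 2 * rotary_dim]
    out = x.clone()
    out[..., :rotary_dim] = x1 * cos - x2 * sin
    out[..., rotary_dim : 2 * rotary_dim] = x1 * sin + x2 * cos
    return out


x = torch.randn(4, 8)        # hypothetical (positions, head_dim) slice, rotary_dim = 4
theta = torch.rand(4, 4)
y = apply_rotary_manual(x, torch.cos(theta), torch.sin(theta))
x_back = apply_rotary_manual(y, torch.cos(-theta), torch.sin(-theta))
print(torch.allclose(x, x_back, atol=1e-6))  # True: the fallback is a plain 2D rotation
```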
@@ -4,8 +4,6 @@ import torch
from vllm import cache_ops
from vllm import attention_ops

from loguru import logger

_PARTITION_SIZE = 512

@@ -56,7 +54,6 @@ def attention(
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
    logger.info(f"paged attention use_v1 {use_v1}")
    if use_v1:
        attention_ops.paged_attention_v1(
            out,
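For context on the `use_v1` heuristic kept by this hunk (the debug `logger.info` line is what gets dropped): vLLM's v2 kernel splits each sequence into partitions and is only worth it when there are several partitions and little other parallel work. A worked example with hypothetical shapes; the `max_num_partitions` derivation below is an assumption, as it is not shown in the hunk:

```python
_PARTITION_SIZE = 512


def pick_kernel(num_seqs: int, num_heads: int, max_s: int) -> str:
    # Assumption: max_num_partitions is the max sequence length rounded up to
    # whole _PARTITION_SIZE blocks (not shown in the hunk above).
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
    return "paged_attention_v1" if use_v1 else "paged_attention_v2"


print(pick_kernel(num_seqs=32, num_heads=32, max_s=256))   # short context -> v1
print(pick_kernel(num_seqs=1, num_heads=32, max_s=4096))   # long context, few seqs/heads -> v2
```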