text-generation-inference/server/custom_kernels/setup.py
fxmarty b2b5df0e94
Add ROCm support (#1243)
This PR adds support for AMD Instinct MI210 & MI250 GPUs, with paged
attention and FlashAttention v2 (FAv2) support.

Remaining items to discuss, on top of possible others:
* Should we have a
`ghcr.io/huggingface/text-generation-inference:1.1.0+rocm` hosted image,
or is it too early?
* Should we set up a CI on MI210/MI250? I don't have access to the
runners of TGI though.
* Are we comfortable with those changes being directly in TGI, or do we
need a fork?

---------

Co-authored-by: Felix Marty <felix@hf.co>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: Your Name <you@example.com>
2023-11-27 14:08:12 +01:00

25 lines
751 B
Python

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch

extra_compile_args = ["-std=c++17"]
if not torch.version.hip:
    extra_compile_args.append("-arch=compute_80")

setup(
    name="custom_kernels",
    ext_modules=[
        CUDAExtension(
            name="custom_kernels.fused_bloom_attention_cuda",
            sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
            extra_compile_args=extra_compile_args,
        ),
        CUDAExtension(
            name="custom_kernels.fused_attention_cuda",
            sources=["custom_kernels/fused_attention_cuda.cu"],
            extra_compile_args=extra_compile_args,
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
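
The flag selection in this file can be sketched in isolation, which makes the CUDA/ROCm branch easy to check without a GPU or a PyTorch install. The helper name `select_compile_args` and its `hip_version` parameter are hypothetical and are not part of the repository; `hip_version` stands in for `torch.version.hip`, which is `None` on CUDA builds of PyTorch and a version string on ROCm builds. The ROCm version string used below is only an illustrative value.

```python
def select_compile_args(hip_version):
    """Return compiler flags for a CUDA or ROCm build (illustrative helper).

    `hip_version` stands in for `torch.version.hip`: None on CUDA builds
    of PyTorch, a version string on ROCm builds.
    """
    args = ["-std=c++17"]
    if not hip_version:
        # -arch=compute_80 is an nvcc option, so it is added only
        # when the build is not targeting ROCm/HIP.
        args.append("-arch=compute_80")
    return args


# CUDA build of PyTorch: torch.version.hip is None.
print(select_compile_args(None))
# ROCm build of PyTorch: torch.version.hip is a version string (example value).
print(select_compile_args("5.7"))
```

In the real `setup.py` the same list is passed to both `CUDAExtension` entries, so the two kernels are always compiled with identical flags.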