text-generation-inference/server/marlin/setup.py
Daniël de Kok cb150eb295
Add support for FP8 on compute capability >=8.0, <8.9 (#2213)
Use FP8 GPTQ-Marlin kernels to enable FP8 support on CUDA GPUs
with compute capability >=8.0 and <8.9.

Co-authored-by: Florian Zimmermeister <flozi00.fz@gmail.com>
2024-07-11 16:03:26 +02:00
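
The capability range in the commit message above can be expressed as a simple tuple comparison. The following is only a minimal sketch: the helper name `needs_fp8_marlin` and the two constants are hypothetical and are not taken from the repository. The only thing drawn from the source is the range itself (>=8.0 and <8.9).

```python
from typing import Tuple

# Bounds taken from the commit message: >=8.0 and <8.9.
# The constant names are illustrative assumptions, not repository identifiers.
FP8_MARLIN_LOWER_BOUND = (8, 0)
FP8_MARLIN_UPPER_BOUND = (8, 9)


def needs_fp8_marlin(capability: Tuple[int, int]) -> bool:
    """Hypothetical helper: True when (major, minor) lies in [8.0, 8.9)."""
    # Tuples compare lexicographically, so (8, 6) sorts between (8, 0) and (8, 9).
    return FP8_MARLIN_LOWER_BOUND <= capability < FP8_MARLIN_UPPER_BOUND
```

With PyTorch installed and a CUDA device present, the `(major, minor)` tuple returned by `torch.cuda.get_device_capability()` could be passed to this helper directly.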

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

extra_compile_args = []

setup(
    name="marlin_kernels",
    ext_modules=[
        CUDAExtension(
            name="marlin_kernels",
            sources=[
                "marlin_kernels/fp8_marlin.cu",
                "marlin_kernels/gptq_marlin.cu",
                "marlin_kernels/gptq_marlin_repack.cu",
                "marlin_kernels/marlin_cuda_kernel.cu",
                "marlin_kernels/sparse/marlin_24_cuda_kernel.cu",
                "marlin_kernels/ext.cpp",
            ],
            extra_compile_args=extra_compile_args,
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)