text-generation-inference/server/marlin/marlin_kernels/__init__.pyi
Daniël de Kok 85c3c5d64f Add support for FP8 on compute capability >=8.0, <8.9 (#2213)
Use FP8 GPTQ-Marlin kernels to enable FP8 support on CUDA GPUs
with compute capability >=8.0 and <8.9.

Co-authored-by: Florian Zimmermeister <flozi00.fz@gmail.com>
2024-09-25 05:27:40 +00:00

77 lines
1.6 KiB
Python

import torch
def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
) -> torch.Tensor:
    """
    Multiply ``a`` by converted GPTQ-quantized weights using a Marlin kernel.

    Extends ``marlin_gemm`` with the additional GPTQ inputs (``g_idx``,
    ``perm``, ``num_bits``, ``is_k_full``) so that converted GPTQ weights
    can be used.

    NOTE(review): the weights are presumably the output of
    ``gptq_marlin_repack``, and ``size_m``/``size_n``/``size_k`` presumably
    describe the GEMM shape (``a`` is m x k, result is m x n) -- confirm
    against the kernel source.
    """
    ...
def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Multiply ``a`` by quantized weights with 2:4 sparsity using Marlin.

    A variant of ``marlin_gemm`` for weights that use 2:4 sparsity.

    NOTE(review): ``b_meta`` presumably carries the sparsity metadata that
    accompanies ``b_q_weight`` -- confirm against the kernel source.
    """
    ...
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """
    Rearrange GPTQ-quantized weight parameters for use by the Marlin kernels.

    Returns the repacked tensor. NOTE(review): ``perm`` is presumably the
    GPTQ row permutation applied during repacking -- confirm.
    """
    ...
def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Multiply ``a`` by quantized weights using the base Marlin kernel.

    Unlike the GPTQ variants declared in this module, it takes no
    ``num_bits``, ``g_idx`` or ``perm`` argument.
    """
    ...
# fp8 marlin
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Matrix multiplication with FP8 weights using the FP8 GPTQ-Marlin kernel.

    Used to enable FP8 support on CUDA GPUs with compute capability >=8.0
    and <8.9.

    NOTE(review): ``size_m``/``size_n``/``size_k`` presumably describe the
    GEMM shape and ``num_bits`` the stored weight width -- confirm against
    the kernel source.
    """
    # Declaration only, like every other entry in this stub: the real
    # implementation lives in the extension module this .pyi describes, so
    # the body must not dispatch to an op namespace (e.g. torch.ops._C).
    ...