Add support for GPTQ Marlin kernels

GPTQ Marlin extends the Marlin kernels to support common GPTQ configurations:

- bits: 4 or 8
- groupsize: -1, 32, 64, or 128
- desc_act: true/false

Using the GPTQ Marlin kernels requires repacking the parameters in the Marlin quantizer format.

The kernels were contributed by Neural Magic to vLLM. We vendor them here for convenience.
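For context, the configuration list above translates directly into a capability check. The helper below is a hypothetical sketch (not part of this commit or the vendored kernels) showing how a quantizer might decide whether the GPTQ Marlin path applies:

# Hypothetical sketch, not part of this commit: gate the GPTQ Marlin
# path on the configurations listed in the commit message above.
GPTQ_MARLIN_BITS = {4, 8}
GPTQ_MARLIN_GROUP_SIZES = {-1, 32, 64, 128}


def can_use_gptq_marlin(bits: int, groupsize: int, desc_act: bool) -> bool:
    # desc_act is supported either way, so it does not restrict kernel
    # selection; it only determines whether a g_idx/perm permutation is
    # passed to the kernels declared below.
    return bits in GPTQ_MARLIN_BITS and groupsize in GPTQ_MARLIN_GROUP_SIZES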
45 lines · 935 B · Python
import torch


def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels. This is an extension of
    `marlin_gemm` that supports converted GPTQ kernels.
    """
    ...


def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack GPTQ parameters for Marlin kernels."""
    ...


def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels.
    """
    ...
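To make the repack-then-multiply sequence concrete, here is a hedged usage sketch. The function signatures follow the stub above, but the import name `marlin_kernels`, the scale and workspace shapes, and the device setup are assumptions, not something this file specifies:

# Hypothetical usage sketch. Assumes the vendored extension imports as
# `marlin_kernels` and runs on a CUDA device; shapes follow common GPTQ
# conventions (int4 values packed 8-per-int32 word, groupsize 128) and
# are assumptions, not taken from this file.
import torch
import marlin_kernels  # assumed import name for the vendored kernels

m, k, n = 16, 4096, 4096       # size_m, size_k, size_n
num_bits = 4
pack_factor = 32 // num_bits   # quantized values per int32 word
groupsize = 128

a = torch.randn(m, k, dtype=torch.float16, device="cuda")
gptq_qweight = torch.zeros((k // pack_factor, n), dtype=torch.int32, device="cuda")
# With desc_act=False there is no activation-order permutation, so
# g_idx and perm are empty tensors and is_k_full is True.
g_idx = torch.empty(0, dtype=torch.int32, device="cuda")
perm = torch.empty(0, dtype=torch.int32, device="cuda")

# One-time repack of the GPTQ weights into the Marlin tile layout,
# done at weight-loading time.
marlin_qweight = marlin_kernels.gptq_marlin_repack(gptq_qweight, perm, k, n, num_bits)

scales = torch.ones((k // groupsize, n), dtype=torch.float16, device="cuda")
# Workspace sizing here is an assumption modeled on typical Marlin usage.
workspace = torch.zeros(n // 64 * 16, dtype=torch.int32, device="cuda")

out = marlin_kernels.gptq_marlin_gemm(
    a, marlin_qweight, scales, g_idx, perm, workspace,
    num_bits, m, n, k, True,
)
print(out.shape)  # (m, n)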