mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Add support for GPTQ Marlin kernels. GPTQ Marlin extends the Marlin kernels to support common GPTQ configurations: bits: 4 or 8; groupsize: -1, 32, 64, or 128; desc_act: true/false. Using the GPTQ Marlin kernels requires repacking the parameters in the Marlin quantizer format. The kernels were contributed by Neural Magic to vLLM. We vendor them here for convenience.
287 lines
8.1 KiB
Python
287 lines
8.1 KiB
Python
from dataclasses import dataclass
|
|
from typing import Optional, Tuple, List
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
|
|
try:
|
|
import marlin_kernels
|
|
except ImportError:
|
|
marlin_kernels = None
|
|
|
|
# Probe the current CUDA device once, at import time. Any failure (no CUDA
# runtime, no visible device, CPU-only torch build, ...) is treated as
# "capability below 8.0".
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    has_sm_8_0 = False
else:
    # Marlin kernels need a device with compute capability 8.0 or newer.
    has_sm_8_0 = major >= 8


# Quantization bit widths accepted by `repack_gptq_for_marlin`.
GPTQ_MARLIN_BITS = [4, 8]
# GPTQ group sizes accepted by `repack_gptq_for_marlin`; -1 means a single
# group spanning all input features.
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
# Number of input features covered by one row of a repacked Marlin weight;
# used to recover `in_features` from the repacked weight's row count.
MARLIN_TILE_SIZE = 16
|
|
|
|
|
|
def _check_marlin_kernels():
    """Raise ``NotImplementedError`` unless the Marlin kernels are usable.

    Two preconditions are checked, in this order:

    1. the server runs on CUDA with a device of compute capability >= 8.0
       (``has_sm_8_0``);
    2. the ``marlin_kernels`` extension was importable.
    """
    if SYSTEM != "cuda" or not has_sm_8_0:
        raise NotImplementedError(
            "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
        )

    if marlin_kernels is None:
        raise NotImplementedError(
            "marlin is not installed, install it with: pip install server/marlin"
        )
|
|
|
|
|
|
def _check_valid_shape(in_features: int, out_features: int):
|
|
if (in_features % 128 != 0 or out_features % 64 != 0) and (
|
|
in_features % 64 != 0 or out_features % 128 != 0
|
|
):
|
|
raise ValueError(
|
|
f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
|
|
" The shape elements must be divisible by (128, 64) or (64, 128)."
|
|
)
|
|
|
|
|
|
# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
|
|
def _get_perms() -> Tuple[List[int], List[int]]:
|
|
scale_perm = []
|
|
for i in range(8):
|
|
scale_perm.extend([i + 8 * j for j in range(8)])
|
|
scale_perm_single = []
|
|
for i in range(4):
|
|
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
|
|
return scale_perm, scale_perm_single
|
|
|
|
|
|
_scale_perm, _scale_perm_single = _get_perms()
|
|
|
|
|
|
def permute_scales(scales: torch.Tensor):
    """Reorder GPTQ scales into the column layout expected by Marlin.

    ``scales`` is assumed to be 2-D, ``(num_groups, out_features)``. A tensor
    with a single row is permuted with ``_scale_perm_single``; otherwise
    ``_scale_perm`` is used. The permutation is applied to consecutive
    chunks of the flattened tensor.

    Returns:
        A contiguous tensor reshaped back to ``(-1, out_features)``.
    """
    n_cols = scales.shape[1]
    perm = _scale_perm_single if scales.shape[0] == 1 else _scale_perm
    chunks = scales.reshape((-1, len(perm)))
    return chunks[:, perm].reshape((-1, n_cols)).contiguous()
|
|
|
|
|
|
@dataclass
class GPTQMarlinWeight:
    """
    GPTQ weights repacked into the GPTQ-Marlin layout.

    Produced by ``repack_gptq_for_marlin`` and consumed by ``GPTQMarlinLinear``.
    """

    # Repacked quantized weight (int32).
    qweight: torch.Tensor
    # Marlin-permuted scales (float16).
    scales: torch.Tensor
    # Group index per input feature, sorted by `perm`; empty when no act-order
    # permutation is used (int32).
    g_idx: torch.Tensor
    # Input-feature permutation; empty when no act-order permutation is used (int32).
    perm: torch.Tensor
    # Quantization bit width.
    bits: int
    # Full-K flag forwarded to the GEMM kernel.
    is_full_k: bool

    def __post_init__(self):
        # Check dtypes in field order; the first mismatch raises AssertionError.
        expected_dtypes = (
            (self.qweight, torch.int32),
            (self.scales, torch.float16),
            (self.g_idx, torch.int32),
            (self.perm, torch.int32),
        )
        for tensor, dtype in expected_dtypes:
            assert tensor.dtype == dtype
|
|
|
|
|
|
def repack_gptq_for_marlin(
    *,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    g_idx: torch.Tensor,
    bits: int,
    desc_act: bool,
    groupsize: int,
    sym: bool,
    sharded_infeatures: bool,
) -> GPTQMarlinWeight:
    """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.

    Args:
        qweight: GPTQ-packed weight of shape
            ``(in_features // (32 // bits), out_features)``; each packed
            element holds ``32 // bits`` quantized values.
        scales: GPTQ scales, assumed ``(num_groups, out_features)``.
        g_idx: group index of every input feature; only used when act-order
            is in effect (``desc_act`` with ``groupsize != -1``).
        bits: quantization bit width, one of ``GPTQ_MARLIN_BITS``.
        desc_act: whether the checkpoint uses act-order.
        groupsize: group size, one of ``GPTQ_MARLIN_GROUP_SIZES``.
        sym: must be True; asymmetric quantization is not supported.
        sharded_infeatures: whether this shard holds only a slice of the
            input features.

    Returns:
        The repacked weight.

    Raises:
        NotImplementedError: if the Marlin kernels are unavailable.
        RuntimeError: for unsupported bits, group size, or asymmetric
            quantization.
        ValueError: if the number of input features is not a multiple of the
            group size.
    """
    _check_marlin_kernels()
    assert marlin_kernels is not None

    if bits not in GPTQ_MARLIN_BITS:
        supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
        raise RuntimeError(
            f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
        )

    if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
        supported_sizes = ", ".join(str(size) for size in GPTQ_MARLIN_GROUP_SIZES)
        raise RuntimeError(
            f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
        )

    if not sym:
        raise RuntimeError(
            "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
        )

    weights_per_int = 32 // bits
    in_features = qweight.shape[0] * weights_per_int
    out_features = qweight.shape[1]

    # `x % -1` is always 0 in Python, so groupsize == -1 (one group) passes.
    if in_features % groupsize != 0:
        raise ValueError(
            f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
        )

    # With a single group (groupsize == -1) the order of input features does
    # not change which scale is used, so act-order is dropped in that case.
    # Everything below (perm, g_idx and is_full_k) must agree on this.
    act_order = desc_act and groupsize != -1

    if act_order:
        # Permutation sorting input features by group index. It is used for
        # the repack here and forwarded to the GEMM via the returned weight.
        perm = torch.argsort(g_idx).to(torch.int)
        g_idx = g_idx[perm]
    else:
        perm = torch.empty(0, dtype=torch.int, device=qweight.device)
        g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

    repacked = marlin_kernels.gptq_marlin_repack(
        qweight, perm, in_features, out_features, bits
    )

    scales = permute_scales(scales)

    # Only act-order weights that are sharded along the input dimension lack
    # the full K. Use `act_order` rather than raw `desc_act` so this flag is
    # consistent with the empty perm/g_idx produced when groupsize == -1.
    is_full_k = not (act_order and sharded_infeatures)

    return GPTQMarlinWeight(
        qweight=repacked,
        scales=scales,
        g_idx=g_idx,
        perm=perm,
        bits=bits,
        is_full_k=is_full_k,
    )
|
|
|
|
|
|
class GPTQMarlinLinear(nn.Module):
    """
    Linear layer backed by the GPTQ-Marlin GEMM kernel.

    Expects a weight that was already converted with ``repack_gptq_for_marlin``.
    """

    def __init__(
        self,
        *,
        weight: GPTQMarlinWeight,
        bias: Optional[torch.Tensor],
    ):
        """
        Args:
            weight: repacked GPTQ-Marlin weight.
            bias: optional bias added to the output; registered as a buffer.

        Raises:
            NotImplementedError: if the Marlin kernels are unavailable.
            ValueError: if the weight shape has no valid kernel tiling.
        """
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        qweight = weight.qweight
        # Each row of the repacked weight covers MARLIN_TILE_SIZE input features.
        in_features = qweight.shape[0] * MARLIN_TILE_SIZE
        out_features = weight.scales.shape[1]
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.bits = weight.bits
        self.is_full_k = weight.is_full_k

        for buffer_name in ("qweight", "scales", "g_idx", "perm"):
            self.register_buffer(buffer_name, getattr(weight, buffer_name))
        if bias is None:
            self.bias = None
        else:
            self.register_buffer("bias", bias)

        # Zero-initialised int scratch tensor passed to every GEMM call; kept
        # as a plain attribute, not a registered buffer.
        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=qweight.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        """Run the GPTQ-Marlin GEMM over the last dimension of ``A``.

        Leading dimensions of ``A`` are flattened for the kernel call and
        restored on the output, whose last dimension is ``out_features``.
        """
        assert marlin_kernels is not None

        in_dim = A.shape[-1]
        out_dim = self.scales.shape[1]
        A_2d = A.view(-1, in_dim)
        C = marlin_kernels.gptq_marlin_gemm(
            A_2d,
            self.qweight,
            self.scales,
            self.g_idx,
            self.perm,
            self.workspace,
            self.bits,
            A_2d.shape[0],
            out_dim,
            in_dim,
            self.is_full_k,
        )
        C = C.reshape(A.shape[:-1] + (out_dim,))

        if self.bias is not None:
            C += self.bias
        return C
|
|
|
|
|
|
@dataclass
class MarlinWeight:
    """
    Quantized weights for the Marlin kernel, as consumed by ``MarlinLinear``.

    Attributes:
        B (torch.Tensor): int4-quantized weights packed into int32.
        s (torch.Tensor): float16 scales.
    """

    B: torch.Tensor
    s: torch.Tensor

    def __post_init__(self):
        # Packed weights are checked first, then scales; a mismatch raises
        # AssertionError.
        for tensor, dtype in ((self.B, torch.int32), (self.s, torch.float16)):
            assert tensor.dtype == dtype
|
|
|
|
|
|
class MarlinLinear(nn.Module):
    """Linear layer backed by the Marlin GEMM kernel (``MarlinWeight`` weights)."""

    def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
        """
        Args:
            weight: Marlin-packed weight and scales.
            bias: optional bias added to the output; registered as a buffer.

        Raises:
            NotImplementedError: if the Marlin kernels are unavailable.
            AssertionError: if in/out features are not multiples of 128/256,
                or the group size is neither -1 nor 128.
        """
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        # Each row of the packed weight covers MARLIN_TILE_SIZE input features.
        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
        out_features = weight.s.shape[1]
        assert (
            in_features % 128 == 0
        ), f"Number of input features ({in_features}) not divisable by 128"
        assert (
            out_features % 256 == 0
        ), f"Number of output features ({out_features}) not divisable by 256"

        # A single scale row means one group spanning all input features (-1).
        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
        assert groupsize in {
            -1,
            128,
        }, f"Group size must be -1 or 128, was {groupsize}"

        self.register_buffer("B", weight.B)
        self.register_buffer("s", weight.s)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None

        # Zero-initialised int scratch tensor passed to every GEMM call; kept
        # as a plain attribute, not a registered buffer.
        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=weight.B.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        """Run the Marlin GEMM over the last dimension of ``A``.

        Leading dimensions of ``A`` are flattened for the kernel call and
        restored on the output, whose last dimension is ``out_features``.
        """
        assert marlin_kernels is not None

        # The problem sizes (m, k) must come from the flattened 2-D view.
        # Reading A.shape[0] / A.shape[1] directly is only correct when A is
        # already 2-D (it mis-sizes rank-3+ inputs and fails on 1-D ones).
        A_flat = A.view(-1, A.shape[-1])
        C = marlin_kernels.marlin_gemm(
            A_flat,
            self.B,
            self.s,
            self.workspace,
            A_flat.shape[0],
            self.s.shape[1],
            A_flat.shape[1],
        )
        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C
|