text-generation-inference/server/text_generation_server/layers/marlin.py
Daniël de Kok 093a27c528
Add support for GPTQ Marlin (#2052)
Add support for GPTQ Marlin kernels

GPTQ Marlin extends the Marlin kernels to support common GPTQ
configurations:

- bits: 4 or 8
- groupsize: -1, 32, 64, or 128
- desc_act: true/false

Using the GPTQ Marlin kernels requires repacking the parameters in the
Marlin quantizer format.

The kernels were contributed by Neural Magic to VLLM. We vendor them
here for convenience.
2024-06-14 09:45:42 +02:00

287 lines
8.1 KiB
Python

from dataclasses import dataclass
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16
def _check_marlin_kernels():
if not (SYSTEM == "cuda" and has_sm_8_0):
raise NotImplementedError(
"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
)
if marlin_kernels is None:
raise NotImplementedError(
"marlin is not installed, install it with: pip install server/marlin"
)
def _check_valid_shape(in_features: int, out_features: int):
if (in_features % 128 != 0 or out_features % 64 != 0) and (
in_features % 64 != 0 or out_features % 128 != 0
):
raise ValueError(
f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
" The shape elements must be divisible by (128, 64) or (64, 128)."
)
# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
def _get_perms() -> Tuple[List[int], List[int]]:
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
_scale_perm, _scale_perm_single = _get_perms()
def permute_scales(scales: torch.Tensor):
out_features = scales.shape[1]
if scales.shape[0] == 1:
scales = scales.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
else:
scales = scales.reshape((-1, len(_scale_perm)))[:, _scale_perm]
return scales.reshape((-1, out_features)).contiguous()
@dataclass
class GPTQMarlinWeight:
"""
Repacked GPTQ Marlin weights.
"""
qweight: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
bits: int
is_full_k: bool
def __post_init__(self):
assert self.qweight.dtype == torch.int32
assert self.scales.dtype == torch.float16
assert self.g_idx.dtype == torch.int32
assert self.perm.dtype == torch.int32
def repack_gptq_for_marlin(
*,
qweight: torch.Tensor,
scales: torch.Tensor,
g_idx: torch.Tensor,
bits: int,
desc_act: bool,
groupsize: int,
sym: bool,
sharded_infeatures: bool,
) -> GPTQMarlinWeight:
"""Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
_check_marlin_kernels()
assert marlin_kernels is not None
if bits not in GPTQ_MARLIN_BITS:
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
raise RuntimeError(
f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
)
if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
raise RuntimeError(
f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
)
if not sym:
raise RuntimeError(
"Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
)
weights_per_int = 32 // bits
in_features = qweight.shape[0] * weights_per_int
out_features = qweight.shape[1]
if in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
)
if desc_act and groupsize != -1:
perm = torch.argsort(g_idx).to(torch.int)
g_idx = g_idx[perm]
else:
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, bits
)
scales = permute_scales(scales)
is_full_k = not (desc_act and sharded_infeatures)
return GPTQMarlinWeight(
qweight=repacked,
scales=scales,
g_idx=g_idx,
perm=perm,
bits=bits,
is_full_k=is_full_k,
)
class GPTQMarlinLinear(nn.Module):
"""
Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
kernels.
"""
def __init__(
self,
*,
weight: GPTQMarlinWeight,
bias: Optional[torch.Tensor],
):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
out_features = weight.scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.bits = weight.bits
self.is_full_k = weight.is_full_k
self.register_buffer("qweight", weight.qweight)
self.register_buffer("scales", weight.scales)
self.register_buffer("g_idx", weight.g_idx)
self.register_buffer("perm", weight.perm)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.gptq_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.g_idx,
self.perm,
self.workspace,
self.bits,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
self.is_full_k,
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
@dataclass
class MarlinWeight:
"""
Marlin weights.
Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
s (torch.Tensor): float16 scales.
"""
B: torch.Tensor
s: torch.Tensor
def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.s.dtype == torch.float16
class MarlinLinear(nn.Module):
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
out_features = weight.s.shape[1]
assert (
in_features % 128 == 0
), f"Number of input features ({in_features}) not divisable by 128"
assert (
out_features % 256 == 0
), f"Number of output features ({out_features}) not divisable by 256"
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
assert groupsize in {
-1,
128,
}, f"Group size must be -1 or 128, was {groupsize}"
self.register_buffer("B", weight.B)
self.register_buffer("s", weight.s)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.B.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
C = marlin_kernels.marlin_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.s,
self.workspace,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
if self.bias is not None:
C += self.bias
return C