text-generation-inference/server/text_generation_server/layers/marlin/util.py
Commit 571ac9b507 by Daniël de Kok (2025-02-10 19:19:25 +01:00): Use kernels from the kernel hub (#2988)
* Use Hub kernels for Marlin and cutlass quantization kernels

* Use hub kernels for MoE/GPTQ-Marlin MoE

* Use attention kernels from the Hub

* Cache the kernels in the Docker image

* Update moe kernels

* Support loading local kernels for development

* Support latest moe kernels

* Update to moe 0.1.1

* CI: download locked kernels for server tests

* Fixup some imports

* CI: activate venv

* Fix unused imports

* Nix: add attention/moe/quantization kernels

* Update hf-kernels to 0.1.5

* Update kernels

* Update tgi-nix flake for hf-kernels

* Fix EOF

* Take `load_kernel` out of a frequently-called function

* Hoist another case of kernel loading out of a somewhat hot function

* marlin-kernels -> quantization

* attention -> paged-attention

* EOF fix

* Update hf-kernels, fixup Docker

* ipex fix

* Remove outdated TODO

import functools
from typing import List, Tuple

import numpy
import torch

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
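
# The quantization kernels (Marlin and cutlass GEMM) are loaded from the kernel
# hub instead of being compiled into the server; they are only available on CUDA.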
if SYSTEM == "cuda":
    quantization = load_kernel(
        module="quantization", repo_id="kernels-community/quantization"
    )
else:
    quantization = None

try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
except Exception:
    has_sm_8_0 = False
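

# Fail fast with a descriptive error when the Marlin kernels cannot be used:
# they require a CUDA GPU with compute capability 8.0+ and the quantization
# kernels to be loaded.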
def _check_marlin_kernels():
    if not (SYSTEM == "cuda" and has_sm_8_0):
        raise NotImplementedError(
            "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
        )

    if quantization is None:
        raise NotImplementedError(
            "Marlin is not available, the `kernels-community/quantization` kernel could not be loaded."
        )


# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
@functools.cache
def get_perms() -> Tuple[List[int], List[int]]:
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])

    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])

    return scale_perm, scale_perm_single
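

# Rearrange quantization scales into the layout expected by the Marlin kernels.
# Scales with a single group (shape[0] == 1) use the reduced `scale_perm_single`
# permutation; per-group scales use the full `scale_perm` permutation.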
def permute_scales(scales: torch.Tensor):
    scale_perm, scale_perm_single = get_perms()
    out_features = scales.shape[1]
    if scales.shape[0] == 1:
        scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    else:
        scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm]
    return scales.reshape((-1, out_features)).contiguous()


# Functions below are from vLLM


def get_pack_factor(bits: int) -> int:
    if 32 % bits != 0:
        raise ValueError(f"Cannot pack {bits}-bit values into uint32")
    return 32 // bits
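

# Pack `pack_factor` low-bit values along the column dimension into int32 words,
# e.g. eight 4-bit values or four 8-bit values per 32-bit word.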
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
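

# The inverse of `pack_cols`: expand int32-packed columns back into one value
# per column.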
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {}, size_k = {}, size_n = {}, pack_factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
    mask = (1 << num_bits) - 1

    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    scale_perm, _ = get_perms()

    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    zp = pack_cols(zp, num_bits, size_k, size_n)

    return zp
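

# A minimal, self-contained smoke test of the packing helpers above. This block
# is illustrative only (not part of the production server) and runs on CPU,
# since `pack_cols`/`unpack_cols` round-trip through numpy.
if __name__ == "__main__":
    size_k, size_n, num_bits = 16, 64, 4
    max_q = (1 << num_bits) - 1

    # Random 4-bit values laid out as a (size_k, size_n) matrix.
    q_w = torch.randint(0, max_q + 1, (size_k, size_n), dtype=torch.int32)

    packed = pack_cols(q_w, num_bits, size_k, size_n)
    assert packed.shape == (size_k, size_n // get_pack_factor(num_bits))

    unpacked = unpack_cols(packed, num_bits, size_k, size_n)
    assert torch.equal(unpacked, q_w), "pack_cols/unpack_cols should round-trip"

    # `permute_scales` reorders entries but keeps the (groups, out_features) shape.
    scales = torch.randn(4, size_n)
    assert permute_scales(scales).shape == scales.shape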