marlin-kernels -> quantization

Author: Daniël de Kok
Date:   2025-02-05 15:10:41 +00:00
Commit: 8ad383c7cb (parent 8aecc59eb0)
6 changed files with 47 additions and 47 deletions
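
All six files below change in the same way: each conditionally loads the quantization kernels from the Hub on CUDA and falls back to None elsewhere, and the commit renames the local binding from marlin_kernels to quantization so it matches the module name. For reference, the shared pattern after the rename looks like this (a sketch assembled from the hunks below, using the imports the files already have):

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
    # Load the quantization kernels from the Hub; the local name now matches the
    # module name instead of the old marlin_kernels alias.
    quantization = load_kernel(
        module="quantization", repo_id="kernels-community/quantization"
    )
else:
    # Call sites assert that this is not None before using any kernel.
    quantization = None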

View File

@@ -12,11 +12,11 @@ from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None

 class W8A8IntLoader(WeightsLoader):
@@ -163,8 +163,8 @@ class Int8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
-            assert marlin_kernels is not None
-            qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
+            assert quantization is not None
+            qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)
             return W8A8IntLinear(
                 bias=bias,
                 input_symmetric=self.input_symmetric,
@@ -208,9 +208,9 @@ class W8A8IntLinear(torch.nn.Module):
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None

-        qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
+        qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(
             input=input,
             scale=None,
             azp=None,
@@ -218,7 +218,7 @@ class W8A8IntLinear(torch.nn.Module):
         )

         if self.input_symmetric:
-            return marlin_kernels.cutlass_scaled_mm(
+            return quantization.cutlass_scaled_mm(
                 a=qinput,
                 b=self.weight,
                 scale_a=input_scale,
@@ -233,7 +233,7 @@ class W8A8IntLinear(torch.nn.Module):
             and (self.input_symmetric or input_zero_point is not None)
         )

-        return marlin_kernels.cutlass_scaled_mm_azp(
+        return quantization.cutlass_scaled_mm_azp(
             a=qinput,
             b=self.weight,
             scale_a=input_scale,

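For context, the dynamic int8 path touched above quantizes activations on the fly with scaled_int8_quant and then runs a CUTLASS scaled matmul against the int8 weight. A hedged sketch of that call sequence with the renamed module (the symmetric keyword and the cutlass_scaled_mm arguments past scale_a are not visible in the diff and are assumptions about the kernel API):

import torch

def w8a8_int_matmul(x: torch.Tensor, qweight, weight_scale, bias, quantization):
    # Quantize activations to int8 with per-token scales (dynamic quantization).
    qinput, input_scale, _ = quantization.scaled_int8_quant(
        input=x,
        scale=None,
        azp=None,
        symmetric=True,  # assumption: this keyword is cut off in the diff above
    )
    # int8 x int8 GEMM with activation and weight scales, dequantized to x.dtype.
    return quantization.cutlass_scaled_mm(
        a=qinput,
        b=qweight,
        scale_a=input_scale,
        scale_b=weight_scale,  # assumption: arguments past scale_a are cut off above
        out_dtype=x.dtype,
        bias=bias,
    )
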
View File

@@ -16,11 +16,11 @@ from text_generation_server.utils.weights import (
 from text_generation_server.utils.log import log_once

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None

 try:
     # TODO: needs to be ported over to MoE and used on CUDA.
@@ -33,9 +33,9 @@ quant_dtype: torch.dtype = (
     torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
 )

-if SYSTEM == "cuda" and marlin_kernels is not None:
+if SYSTEM == "cuda" and quantization is not None:
     major, minor = torch.cuda.get_device_capability()
-    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
+    CUTLASS_FP8_AVAILABLE = quantization.cutlass_scaled_mm_supports_fp8(
         major * 10 + minor
     )
 else:
@@ -147,9 +147,9 @@ def fp8_quantize(
     argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
     be used without modification).
     """
-    if marlin_kernels is not None:
+    if quantization is not None:
         shape = weight.shape
-        qweight, scale = marlin_kernels.scaled_fp8_quant(
+        qweight, scale = quantization.scaled_fp8_quant(
             weight.reshape(-1, shape[-1]),
             scale=scale,
             scale_ub=scale_upper_bound,
@@ -530,7 +530,7 @@ class Fp8Linear(torch.nn.Module):
         qinput, scale = fp8_quantize(
             input, scale_upper_bound=self.scale_upper_bound, scalar=False
         )
-        return marlin_kernels.cutlass_scaled_mm(
+        return quantization.cutlass_scaled_mm(
             qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
         )

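The FP8 file follows the same shape: a module-level check gates the CUTLASS path, fp8_quantize wraps quantization.scaled_fp8_quant, and Fp8Linear.forward ends in the scaled matmul shown in the last hunk. A condensed sketch of those two pieces (the False fallback in the else branch is an assumption; it is not visible in the diff):

import torch

# Module-level gate from the hunk above: CUTLASS FP8 needs the kernels and a recent
# GPU (capability encoded as major * 10 + minor, e.g. 89 for compute capability 8.9).
if SYSTEM == "cuda" and quantization is not None:
    major, minor = torch.cuda.get_device_capability()
    CUTLASS_FP8_AVAILABLE = quantization.cutlass_scaled_mm_supports_fp8(
        major * 10 + minor
    )
else:
    CUTLASS_FP8_AVAILABLE = False  # assumption: fallback not shown in the diff

def fp8_linear_forward(self, input):
    # Quantize activations to FP8 with per-token scales, then run a CUTLASS scaled
    # matmul against the pre-quantized weight, cast back to the input dtype.
    qinput, scale = fp8_quantize(
        input, scale_upper_bound=self.scale_upper_bound, scalar=False
    )
    return quantization.cutlass_scaled_mm(
        qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
    )
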
View File

@@ -12,11 +12,11 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None

 MARLIN_TILE_SIZE = 16
@@ -36,7 +36,7 @@ class GPTQMarlinFP8Linear(nn.Module):
         super().__init__()
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None

         scales = scales.unsqueeze(0)
         if scales.shape[1] == 1:
@@ -73,10 +73,10 @@ class GPTQMarlinFP8Linear(nn.Module):
         return cls(qweight=weight, scales=scale.to(dtype), bias=bias)

     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None

         A_flat = A.view(-1, A.shape[-1])
-        C = marlin_kernels.fp8_marlin_gemm(
+        C = quantization.fp8_marlin_gemm(
             A_flat,
             self.qweight,
             self.scales,
@@ -138,7 +138,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
     qweight = pack_fp8_as_int32(weight.t())
     perm = torch.empty(0, dtype=torch.int, device=qweight.device)
-    repacked = marlin_kernels.gptq_marlin_repack(
+    repacked = quantization.gptq_marlin_repack(
         qweight, perm, in_features, out_features, 8
     )

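The last hunk of this file is the complete FP8 repack step, which reads more clearly in one piece: the FP8 weight is packed into int32 words and then rearranged into Marlin's tiled layout with an empty permutation. A sketch of that sequence (pack_fp8_as_int32 is the helper defined in this file; the final 8 is the bit width of the packed values):

import torch

def repack_fp8_weight(weight: torch.Tensor, in_features: int, out_features: int):
    # Pack four FP8 bytes into each int32, transposed to the layout the repack kernel expects.
    qweight = pack_fp8_as_int32(weight.t())
    # FP8 needs no activation reordering, so the permutation tensor is empty.
    perm = torch.empty(0, dtype=torch.int, device=qweight.device)
    # Rearrange into the Marlin tile layout; 8 is the number of bits per value.
    return quantization.gptq_marlin_repack(qweight, perm, in_features, out_features, 8)
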
View File

@@ -17,11 +17,11 @@ from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None

 try:
@@ -41,7 +41,7 @@ def can_use_gptq_marlin(
 ) -> bool:
     return (
         SYSTEM == "cuda"
-        and marlin_kernels is not None
+        and quantization is not None
         and has_sm_8_0
         and quantize in {"awq", "gptq"}
         and quant_method in {"awq", "gptq"}
@@ -291,7 +291,7 @@ def repack_gptq_for_marlin(
 ) -> GPTQMarlinWeight:
     """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
     _check_marlin_kernels()
-    assert marlin_kernels is not None
+    assert quantization is not None

     if bits not in GPTQ_MARLIN_BITS:
         supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
@@ -334,7 +334,7 @@ def repack_gptq_for_marlin(
         g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

     if quant_method == "awq":
-        repacked = marlin_kernels.awq_marlin_repack(
+        repacked = quantization.awq_marlin_repack(
             qweight, in_features, out_features, bits
         )
         if qzeros is not None:
@@ -346,7 +346,7 @@
            )
     else:
-        repacked = marlin_kernels.gptq_marlin_repack(
+        repacked = quantization.gptq_marlin_repack(
             qweight, perm, in_features, out_features, bits
         )
@@ -383,7 +383,7 @@ class GPTQMarlinLinear(nn.Module):
         super().__init__()
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None

         in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
         out_features = weight.scales.shape[1]
@@ -394,14 +394,14 @@ class GPTQMarlinLinear(nn.Module):
         if weight.qzeros.numel() > 0:
             if weight.bits == 4:
-                self.quant_type = marlin_kernels.scalar_types.uint4
+                self.quant_type = quantization.scalar_types.uint4
             else:
-                self.quant_type = marlin_kernels.scalar_types.uint8
+                self.quant_type = quantization.scalar_types.uint8
         else:
             if weight.bits == 4:
-                self.quant_type = marlin_kernels.scalar_types.uint4b8
+                self.quant_type = quantization.scalar_types.uint4b8
             else:
-                self.quant_type = marlin_kernels.scalar_types.uint8b128
+                self.quant_type = quantization.scalar_types.uint8b128

         self.is_full_k = weight.is_full_k
@@ -420,10 +420,10 @@ class GPTQMarlinLinear(nn.Module):
         )

     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None

         A_flat = A.view(-1, A.shape[-1])
-        C = marlin_kernels.gptq_marlin_gemm(
+        C = quantization.gptq_marlin_gemm(
             A_flat,
             self.qweight,
             self.scales,

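Beyond the renames, the largest hunk in this file is the scalar-type selection for GPTQ-Marlin. Pulled out of the constructor for readability (a sketch, not a helper that exists in the file): checkpoints with explicit zero points use the plain unsigned types, while zero-point-free GPTQ weights use the bias-encoded variants.

def marlin_quant_type(bits: int, has_qzeros: bool):
    # Explicit zero points (qzeros present): plain unsigned integer types.
    if has_qzeros:
        return (
            quantization.scalar_types.uint4
            if bits == 4
            else quantization.scalar_types.uint8
        )
    # No zero points: bias-encoded types (uint4 with bias 8, uint8 with bias 128).
    return (
        quantization.scalar_types.uint4b8
        if bits == 4
        else quantization.scalar_types.uint8b128
    )
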
View File

@@ -10,11 +10,11 @@ from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None

 class MarlinWeightsLoader(WeightsLoader):
@@ -192,7 +192,7 @@ class MarlinLinear(nn.Module):
         super().__init__()
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None

         in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
         out_features = weight.s.shape[1]
@@ -221,9 +221,9 @@ class MarlinLinear(nn.Module):
         )

     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None

-        C = marlin_kernels.marlin_gemm(
+        C = quantization.marlin_gemm(
             A.view(-1, A.shape[-1]),
             self.B,
             self.s,
@@ -282,7 +282,7 @@ class GPTQMarlin24Linear(nn.Module):
         super().__init__()
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None

         if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
             supported_bits = ", ".join(
@@ -309,9 +309,9 @@ class GPTQMarlin24Linear(nn.Module):
             )

         if weight.bits == 4:
-            self.quant_type = marlin_kernels.scalar_types.uint4b8
+            self.quant_type = quantization.scalar_types.uint4b8
         else:
-            self.quant_type = marlin_kernels.scalar_types.uint8b128
+            self.quant_type = quantization.scalar_types.uint8b128

         weights_per_int32 = 32 // weight.bits
         assert (
@@ -344,9 +344,9 @@ class GPTQMarlin24Linear(nn.Module):
         )

     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None

-        C = marlin_kernels.gptq_marlin_24_gemm(
+        C = quantization.gptq_marlin_24_gemm(
             A.view(-1, A.shape[-1]),
             self.weight_packed,
             self.meta,

View File

@@ -7,11 +7,11 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None

 try:
     major, _minor = torch.cuda.get_device_capability()
@@ -26,7 +26,7 @@ def _check_marlin_kernels():
             "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
         )

-    if marlin_kernels is None:
+    if quantization is None:
         raise NotImplementedError(
             "marlin is not installed, install it with: pip install server/marlin"
         )
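
Put back together, the guard that the Marlin layers above call before touching the kernels reads roughly as follows; the has_sm_8_0 flag comes from the try block in the first hunk of this file, and the exact shape of the first condition is an assumption since it is not visible in the diff. Note that the error message still points at the old pip install server/marlin flow; this commit only renames the binding.

def _check_marlin_kernels():
    # Marlin kernels require CUDA and at least compute capability 8.0 (assumed condition).
    if not (SYSTEM == "cuda" and has_sm_8_0):
        raise NotImplementedError(
            "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
        )
    # quantization is None when SYSTEM != "cuda" and load_kernel() was never called.
    if quantization is None:
        raise NotImplementedError(
            "marlin is not installed, install it with: pip install server/marlin"
        )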