marlin-kernels -> quantization

Daniël de Kok 2025-02-05 15:10:41 +00:00
parent 8aecc59eb0
commit 8ad383c7cb
6 changed files with 47 additions and 47 deletions
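For context, `quantization` is now the handle returned by the kernel loader for the kernels-community/quantization repository on the Hugging Face Hub. Below is a minimal standalone sketch of equivalent usage, assuming a CUDA device and that `load_kernel` resolves to something like the `kernels` package's `get_kernel`; the tensor shape and the final print are illustrative only, not taken from the codebase:

    import torch
    from kernels import get_kernel  # Hub kernel loader; load_kernel wraps a similar mechanism

    # Fetch the compiled quantization kernels from the Hub (CUDA only).
    quantization = get_kernel("kernels-community/quantization")

    # Dynamic per-tensor FP8 quantization of a half-precision weight matrix.
    weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    qweight, scale = quantization.scaled_fp8_quant(weight)
    print(qweight.dtype, scale)  # float8 weights plus the per-tensor scale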

View File

@@ -12,11 +12,11 @@ from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None
 class W8A8IntLoader(WeightsLoader):
@@ -163,8 +163,8 @@ class Int8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
-            assert marlin_kernels is not None
-            qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
+            assert quantization is not None
+            qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)
             return W8A8IntLinear(
                 bias=bias,
                 input_symmetric=self.input_symmetric,
@@ -208,9 +208,9 @@ class W8A8IntLinear(torch.nn.Module):
         )
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
-        qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
+        assert quantization is not None
+        qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(
             input=input,
             scale=None,
             azp=None,
@@ -218,7 +218,7 @@ class W8A8IntLinear(torch.nn.Module):
         )
         if self.input_symmetric:
-            return marlin_kernels.cutlass_scaled_mm(
+            return quantization.cutlass_scaled_mm(
                 a=qinput,
                 b=self.weight,
                 scale_a=input_scale,
@@ -233,7 +233,7 @@ class W8A8IntLinear(torch.nn.Module):
             and (self.input_symmetric or input_zero_point is not None)
         )
-        return marlin_kernels.cutlass_scaled_mm_azp(
+        return quantization.cutlass_scaled_mm_azp(
            a=qinput,
            b=self.weight,
            scale_a=input_scale,
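The W8A8 int8 path shown above quantizes activations on the fly and feeds them to a CUTLASS scaled GEMM against the pre-quantized int8 weight. A rough sketch of the symmetric branch follows; since the hunks are truncated, the trailing keyword arguments (`symmetric`, `scale_b`, `out_dtype`, `bias`) follow the usual vLLM-style signatures and are an assumption, as is the helper name itself:

    from typing import Optional

    import torch


    def w8a8_int8_forward_symmetric(
        quantization,              # module returned by load_kernel
        x: torch.Tensor,           # fp16/bf16 activations, [tokens, in_features]
        weight_q: torch.Tensor,    # int8 weight, laid out as the kernel expects
        weight_scale: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Dynamic symmetric per-token int8 quantization of the input.
        qinput, input_scale, _ = quantization.scaled_int8_quant(
            input=x, scale=None, azp=None, symmetric=True
        )
        # int8 x int8 GEMM, rescaled back to the activation dtype.
        return quantization.cutlass_scaled_mm(
            a=qinput,
            b=weight_q,
            scale_a=input_scale,
            scale_b=weight_scale,
            out_dtype=x.dtype,
            bias=bias,
        )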

View File

@@ -16,11 +16,11 @@ from text_generation_server.utils.weights import (
 from text_generation_server.utils.log import log_once
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None
 try:
     # TODO: needs to be ported over to MoE and used on CUDA.
@@ -33,9 +33,9 @@ quant_dtype: torch.dtype = (
     torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
 )
-if SYSTEM == "cuda" and marlin_kernels is not None:
+if SYSTEM == "cuda" and quantization is not None:
     major, minor = torch.cuda.get_device_capability()
-    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
+    CUTLASS_FP8_AVAILABLE = quantization.cutlass_scaled_mm_supports_fp8(
         major * 10 + minor
     )
 else:
@@ -147,9 +147,9 @@ def fp8_quantize(
     argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
     be used without modification).
     """
-    if marlin_kernels is not None:
+    if quantization is not None:
         shape = weight.shape
-        qweight, scale = marlin_kernels.scaled_fp8_quant(
+        qweight, scale = quantization.scaled_fp8_quant(
             weight.reshape(-1, shape[-1]),
             scale=scale,
             scale_ub=scale_upper_bound,
@@ -530,7 +530,7 @@ class Fp8Linear(torch.nn.Module):
         qinput, scale = fp8_quantize(
             input, scale_upper_bound=self.scale_upper_bound, scalar=False
         )
-        return marlin_kernels.cutlass_scaled_mm(
+        return quantization.cutlass_scaled_mm(
             qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
         )
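The FP8 path works the same way: probe once whether the device's CUTLASS FP8 GEMM is supported, then quantize activations dynamically and call `cutlass_scaled_mm` with the transposed FP8 weight, as in `Fp8Linear.forward` above. A condensed sketch (the non-CUTLASS fallback is omitted; relying on the default arguments of `scaled_fp8_quant` is an assumption):

    import torch


    def cutlass_fp8_available(quantization) -> bool:
        # Mirrors the capability probe above: the kernel is queried with the
        # capability encoded as major * 10 + minor (e.g. 89 for Ada, 90 for Hopper).
        major, minor = torch.cuda.get_device_capability()
        return quantization.cutlass_scaled_mm_supports_fp8(major * 10 + minor)


    def fp8_linear(quantization, x, qweight, weight_scale, bias=None):
        # Dynamic per-tensor FP8 quantization of the flattened activations.
        qinput, input_scale = quantization.scaled_fp8_quant(x.reshape(-1, x.shape[-1]))
        # FP8 x FP8 GEMM; the weight is stored [out_features, in_features]
        # and transposed here, as in Fp8Linear.forward.
        return quantization.cutlass_scaled_mm(
            qinput, qweight.t(), input_scale, weight_scale, x.dtype, bias
        )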

View File

@@ -12,11 +12,11 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None
 MARLIN_TILE_SIZE = 16
@@ -36,7 +36,7 @@ class GPTQMarlinFP8Linear(nn.Module):
         super().__init__()
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None
         scales = scales.unsqueeze(0)
         if scales.shape[1] == 1:
@@ -73,10 +73,10 @@ class GPTQMarlinFP8Linear(nn.Module):
         return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None
         A_flat = A.view(-1, A.shape[-1])
-        C = marlin_kernels.fp8_marlin_gemm(
+        C = quantization.fp8_marlin_gemm(
             A_flat,
             self.qweight,
             self.scales,
@@ -138,7 +138,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
     qweight = pack_fp8_as_int32(weight.t())
     perm = torch.empty(0, dtype=torch.int, device=qweight.device)
-    repacked = marlin_kernels.gptq_marlin_repack(
+    repacked = quantization.gptq_marlin_repack(
         qweight, perm, in_features, out_features, 8
     )

View File

@@ -17,11 +17,11 @@ from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None
 try:
@@ -41,7 +41,7 @@ def can_use_gptq_marlin(
 ) -> bool:
     return (
         SYSTEM == "cuda"
-        and marlin_kernels is not None
+        and quantization is not None
         and has_sm_8_0
         and quantize in {"awq", "gptq"}
         and quant_method in {"awq", "gptq"}
@@ -291,7 +291,7 @@ def repack_gptq_for_marlin(
 ) -> GPTQMarlinWeight:
     """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
     _check_marlin_kernels()
-    assert marlin_kernels is not None
+    assert quantization is not None
     if bits not in GPTQ_MARLIN_BITS:
         supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
@@ -334,7 +334,7 @@ def repack_gptq_for_marlin(
         g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
     if quant_method == "awq":
-        repacked = marlin_kernels.awq_marlin_repack(
+        repacked = quantization.awq_marlin_repack(
             qweight, in_features, out_features, bits
         )
         if qzeros is not None:
@@ -346,7 +346,7 @@
             )
     else:
-        repacked = marlin_kernels.gptq_marlin_repack(
+        repacked = quantization.gptq_marlin_repack(
             qweight, perm, in_features, out_features, bits
         )
@@ -383,7 +383,7 @@ class GPTQMarlinLinear(nn.Module):
         super().__init__()
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None
         in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
         out_features = weight.scales.shape[1]
@@ -394,14 +394,14 @@ class GPTQMarlinLinear(nn.Module):
         if weight.qzeros.numel() > 0:
             if weight.bits == 4:
-                self.quant_type = marlin_kernels.scalar_types.uint4
+                self.quant_type = quantization.scalar_types.uint4
             else:
-                self.quant_type = marlin_kernels.scalar_types.uint8
+                self.quant_type = quantization.scalar_types.uint8
         else:
             if weight.bits == 4:
-                self.quant_type = marlin_kernels.scalar_types.uint4b8
+                self.quant_type = quantization.scalar_types.uint4b8
             else:
-                self.quant_type = marlin_kernels.scalar_types.uint8b128
+                self.quant_type = quantization.scalar_types.uint8b128
         self.is_full_k = weight.is_full_k
@@ -420,10 +420,10 @@ class GPTQMarlinLinear(nn.Module):
         )
     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None
         A_flat = A.view(-1, A.shape[-1])
-        C = marlin_kernels.gptq_marlin_gemm(
+        C = quantization.gptq_marlin_gemm(
             A_flat,
             self.qweight,
             self.scales,
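The `quant_type` selection above distinguishes checkpoints that carry explicit zero points (plain `uint4`/`uint8`) from symmetric GPTQ checkpoints, where `uint4b8` and `uint8b128` denote unsigned types with an implicit bias of 8 and 128 respectively. The same logic as a small standalone helper (the function name is illustrative, not part of the codebase):

    def marlin_quant_type(quantization, bits: int, has_qzeros: bool):
        # Explicit zero points: use the plain unsigned scalar types.
        if has_qzeros:
            return (
                quantization.scalar_types.uint4
                if bits == 4
                else quantization.scalar_types.uint8
            )
        # No zero points: GPTQ's symmetric layout with a fixed bias of 8 / 128.
        return (
            quantization.scalar_types.uint4b8
            if bits == 4
            else quantization.scalar_types.uint8b128
        )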

View File

@@ -10,11 +10,11 @@ from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None
 class MarlinWeightsLoader(WeightsLoader):
@@ -192,7 +192,7 @@ class MarlinLinear(nn.Module):
         super().__init__()
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None
         in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
         out_features = weight.s.shape[1]
@@ -221,9 +221,9 @@ class MarlinLinear(nn.Module):
         )
     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
-        C = marlin_kernels.marlin_gemm(
+        assert quantization is not None
+        C = quantization.marlin_gemm(
             A.view(-1, A.shape[-1]),
             self.B,
             self.s,
@@ -282,7 +282,7 @@ class GPTQMarlin24Linear(nn.Module):
         super().__init__()
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None
         if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
             supported_bits = ", ".join(
@@ -309,9 +309,9 @@ class GPTQMarlin24Linear(nn.Module):
             )
         if weight.bits == 4:
-            self.quant_type = marlin_kernels.scalar_types.uint4b8
+            self.quant_type = quantization.scalar_types.uint4b8
         else:
-            self.quant_type = marlin_kernels.scalar_types.uint8b128
+            self.quant_type = quantization.scalar_types.uint8b128
         weights_per_int32 = 32 // weight.bits
         assert (
@@ -344,9 +344,9 @@ class GPTQMarlin24Linear(nn.Module):
         )
     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
-        C = marlin_kernels.gptq_marlin_24_gemm(
+        assert quantization is not None
+        C = quantization.gptq_marlin_24_gemm(
             A.view(-1, A.shape[-1]),
             self.weight_packed,
             self.meta,

View File

@@ -7,11 +7,11 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None
 try:
     major, _minor = torch.cuda.get_device_capability()
@@ -26,7 +26,7 @@ def _check_marlin_kernels():
             "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
         )
-    if marlin_kernels is None:
+    if quantization is None:
         raise NotImplementedError(
             "marlin is not installed, install it with: pip install server/marlin"
         )