mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
marlin-kernels -> quantization
This commit is contained in:
parent
8aecc59eb0
commit
8ad383c7cb
@ -12,11 +12,11 @@ from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
quantization = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
else:
|
||||
marlin_kernels = None
|
||||
quantization = None
|
||||
|
||||
|
||||
class W8A8IntLoader(WeightsLoader):
|
||||
@ -163,8 +163,8 @@ class Int8Weight(Weight):
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
if self.weight_scale is None:
|
||||
assert marlin_kernels is not None
|
||||
qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
|
||||
assert quantization is not None
|
||||
qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)
|
||||
return W8A8IntLinear(
|
||||
bias=bias,
|
||||
input_symmetric=self.input_symmetric,
|
||||
@ -208,9 +208,9 @@ class W8A8IntLinear(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
|
||||
qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(
|
||||
input=input,
|
||||
scale=None,
|
||||
azp=None,
|
||||
@ -218,7 +218,7 @@ class W8A8IntLinear(torch.nn.Module):
|
||||
)
|
||||
|
||||
if self.input_symmetric:
|
||||
return marlin_kernels.cutlass_scaled_mm(
|
||||
return quantization.cutlass_scaled_mm(
|
||||
a=qinput,
|
||||
b=self.weight,
|
||||
scale_a=input_scale,
|
||||
@ -233,7 +233,7 @@ class W8A8IntLinear(torch.nn.Module):
|
||||
and (self.input_symmetric or input_zero_point is not None)
|
||||
)
|
||||
|
||||
return marlin_kernels.cutlass_scaled_mm_azp(
|
||||
return quantization.cutlass_scaled_mm_azp(
|
||||
a=qinput,
|
||||
b=self.weight,
|
||||
scale_a=input_scale,
|
||||
|
@ -16,11 +16,11 @@ from text_generation_server.utils.weights import (
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
quantization = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
else:
|
||||
marlin_kernels = None
|
||||
quantization = None
|
||||
|
||||
try:
|
||||
# TODO: needs to be ported over to MoE and used on CUDA.
|
||||
@ -33,9 +33,9 @@ quant_dtype: torch.dtype = (
|
||||
torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
if SYSTEM == "cuda" and marlin_kernels is not None:
|
||||
if SYSTEM == "cuda" and quantization is not None:
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
|
||||
CUTLASS_FP8_AVAILABLE = quantization.cutlass_scaled_mm_supports_fp8(
|
||||
major * 10 + minor
|
||||
)
|
||||
else:
|
||||
@ -147,9 +147,9 @@ def fp8_quantize(
|
||||
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
|
||||
be used without modification).
|
||||
"""
|
||||
if marlin_kernels is not None:
|
||||
if quantization is not None:
|
||||
shape = weight.shape
|
||||
qweight, scale = marlin_kernels.scaled_fp8_quant(
|
||||
qweight, scale = quantization.scaled_fp8_quant(
|
||||
weight.reshape(-1, shape[-1]),
|
||||
scale=scale,
|
||||
scale_ub=scale_upper_bound,
|
||||
@ -530,7 +530,7 @@ class Fp8Linear(torch.nn.Module):
|
||||
qinput, scale = fp8_quantize(
|
||||
input, scale_upper_bound=self.scale_upper_bound, scalar=False
|
||||
)
|
||||
return marlin_kernels.cutlass_scaled_mm(
|
||||
return quantization.cutlass_scaled_mm(
|
||||
qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
|
||||
)
|
||||
|
||||
|
@ -12,11 +12,11 @@ from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.kernels import load_kernel
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
quantization = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
else:
|
||||
marlin_kernels = None
|
||||
quantization = None
|
||||
|
||||
|
||||
MARLIN_TILE_SIZE = 16
|
||||
@ -36,7 +36,7 @@ class GPTQMarlinFP8Linear(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
scales = scales.unsqueeze(0)
|
||||
if scales.shape[1] == 1:
|
||||
@ -73,10 +73,10 @@ class GPTQMarlinFP8Linear(nn.Module):
|
||||
return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
A_flat = A.view(-1, A.shape[-1])
|
||||
C = marlin_kernels.fp8_marlin_gemm(
|
||||
C = quantization.fp8_marlin_gemm(
|
||||
A_flat,
|
||||
self.qweight,
|
||||
self.scales,
|
||||
@ -138,7 +138,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
|
||||
qweight = pack_fp8_as_int32(weight.t())
|
||||
|
||||
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
|
||||
repacked = marlin_kernels.gptq_marlin_repack(
|
||||
repacked = quantization.gptq_marlin_repack(
|
||||
qweight, perm, in_features, out_features, 8
|
||||
)
|
||||
|
||||
|
@ -17,11 +17,11 @@ from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
quantization = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
else:
|
||||
marlin_kernels = None
|
||||
quantization = None
|
||||
|
||||
|
||||
try:
|
||||
@ -41,7 +41,7 @@ def can_use_gptq_marlin(
|
||||
) -> bool:
|
||||
return (
|
||||
SYSTEM == "cuda"
|
||||
and marlin_kernels is not None
|
||||
and quantization is not None
|
||||
and has_sm_8_0
|
||||
and quantize in {"awq", "gptq"}
|
||||
and quant_method in {"awq", "gptq"}
|
||||
@ -291,7 +291,7 @@ def repack_gptq_for_marlin(
|
||||
) -> GPTQMarlinWeight:
|
||||
"""Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
if bits not in GPTQ_MARLIN_BITS:
|
||||
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
|
||||
@ -334,7 +334,7 @@ def repack_gptq_for_marlin(
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
|
||||
|
||||
if quant_method == "awq":
|
||||
repacked = marlin_kernels.awq_marlin_repack(
|
||||
repacked = quantization.awq_marlin_repack(
|
||||
qweight, in_features, out_features, bits
|
||||
)
|
||||
if qzeros is not None:
|
||||
@ -346,7 +346,7 @@ def repack_gptq_for_marlin(
|
||||
)
|
||||
|
||||
else:
|
||||
repacked = marlin_kernels.gptq_marlin_repack(
|
||||
repacked = quantization.gptq_marlin_repack(
|
||||
qweight, perm, in_features, out_features, bits
|
||||
)
|
||||
|
||||
@ -383,7 +383,7 @@ class GPTQMarlinLinear(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
|
||||
out_features = weight.scales.shape[1]
|
||||
@ -394,14 +394,14 @@ class GPTQMarlinLinear(nn.Module):
|
||||
|
||||
if weight.qzeros.numel() > 0:
|
||||
if weight.bits == 4:
|
||||
self.quant_type = marlin_kernels.scalar_types.uint4
|
||||
self.quant_type = quantization.scalar_types.uint4
|
||||
else:
|
||||
self.quant_type = marlin_kernels.scalar_types.uint8
|
||||
self.quant_type = quantization.scalar_types.uint8
|
||||
else:
|
||||
if weight.bits == 4:
|
||||
self.quant_type = marlin_kernels.scalar_types.uint4b8
|
||||
self.quant_type = quantization.scalar_types.uint4b8
|
||||
else:
|
||||
self.quant_type = marlin_kernels.scalar_types.uint8b128
|
||||
self.quant_type = quantization.scalar_types.uint8b128
|
||||
|
||||
self.is_full_k = weight.is_full_k
|
||||
|
||||
@ -420,10 +420,10 @@ class GPTQMarlinLinear(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
A_flat = A.view(-1, A.shape[-1])
|
||||
C = marlin_kernels.gptq_marlin_gemm(
|
||||
C = quantization.gptq_marlin_gemm(
|
||||
A_flat,
|
||||
self.qweight,
|
||||
self.scales,
|
||||
|
@ -10,11 +10,11 @@ from text_generation_server.utils.kernels import load_kernel
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
quantization = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
else:
|
||||
marlin_kernels = None
|
||||
quantization = None
|
||||
|
||||
|
||||
class MarlinWeightsLoader(WeightsLoader):
|
||||
@ -192,7 +192,7 @@ class MarlinLinear(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
|
||||
out_features = weight.s.shape[1]
|
||||
@ -221,9 +221,9 @@ class MarlinLinear(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
C = marlin_kernels.marlin_gemm(
|
||||
C = quantization.marlin_gemm(
|
||||
A.view(-1, A.shape[-1]),
|
||||
self.B,
|
||||
self.s,
|
||||
@ -282,7 +282,7 @@ class GPTQMarlin24Linear(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
|
||||
supported_bits = ", ".join(
|
||||
@ -309,9 +309,9 @@ class GPTQMarlin24Linear(nn.Module):
|
||||
)
|
||||
|
||||
if weight.bits == 4:
|
||||
self.quant_type = marlin_kernels.scalar_types.uint4b8
|
||||
self.quant_type = quantization.scalar_types.uint4b8
|
||||
else:
|
||||
self.quant_type = marlin_kernels.scalar_types.uint8b128
|
||||
self.quant_type = quantization.scalar_types.uint8b128
|
||||
weights_per_int32 = 32 // weight.bits
|
||||
|
||||
assert (
|
||||
@ -344,9 +344,9 @@ class GPTQMarlin24Linear(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
C = marlin_kernels.gptq_marlin_24_gemm(
|
||||
C = quantization.gptq_marlin_24_gemm(
|
||||
A.view(-1, A.shape[-1]),
|
||||
self.weight_packed,
|
||||
self.meta,
|
||||
|
@ -7,11 +7,11 @@ from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.kernels import load_kernel
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
quantization = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
else:
|
||||
marlin_kernels = None
|
||||
quantization = None
|
||||
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
@ -26,7 +26,7 @@ def _check_marlin_kernels():
|
||||
"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
|
||||
)
|
||||
|
||||
if marlin_kernels is None:
|
||||
if quantization is None:
|
||||
raise NotImplementedError(
|
||||
"marlin is not installed, install it with: pip install server/marlin"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user