Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

commit 8ad383c7cb
parent 8aecc59eb0

    marlin-kernels -> quantization
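The change is mechanical: each affected layer module now binds the Hub kernel loaded from kernels-community/quantization to the name `quantization` instead of the legacy `marlin_kernels`, and every call site swaps the prefix while keeping its existing `assert ... is not None` guard. The sketch below restates that pattern for reference; the `int8_linear` wrapper is a hypothetical illustration of how the renamed binding is used, not code from this commit.

import torch

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
    # Same kernel repository as before; only the local binding is renamed.
    quantization = load_kernel(
        module="quantization", repo_id="kernels-community/quantization"
    )
else:
    quantization = None


def int8_linear(x, weight_t, weight_scale, bias=None):
    # Hypothetical wrapper mirroring W8A8IntLinear.forward (symmetric case):
    # fail fast if the kernel was not loaded, quantize the activations, then
    # run the scaled GEMM through the renamed binding.
    assert quantization is not None
    qx, x_scale, _ = quantization.scaled_int8_quant(
        input=x, scale=None, azp=None, symmetric=True
    )
    return quantization.cutlass_scaled_mm(
        a=qx,
        b=weight_t,
        scale_a=x_scale,
        scale_b=weight_scale,
        out_dtype=x.dtype,
        bias=bias,
    )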
@@ -12,11 +12,11 @@ from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None


 class W8A8IntLoader(WeightsLoader):
@@ -163,8 +163,8 @@ class Int8Weight(Weight):

     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
-            assert marlin_kernels is not None
-            qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
+            assert quantization is not None
+            qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)
             return W8A8IntLinear(
                 bias=bias,
                 input_symmetric=self.input_symmetric,
@@ -208,9 +208,9 @@ class W8A8IntLinear(torch.nn.Module):
             )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None

-        qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
+        qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(
             input=input,
             scale=None,
             azp=None,
@@ -218,7 +218,7 @@ class W8A8IntLinear(torch.nn.Module):
         )

         if self.input_symmetric:
-            return marlin_kernels.cutlass_scaled_mm(
+            return quantization.cutlass_scaled_mm(
                 a=qinput,
                 b=self.weight,
                 scale_a=input_scale,
@@ -233,7 +233,7 @@ class W8A8IntLinear(torch.nn.Module):
                 and (self.input_symmetric or input_zero_point is not None)
             )

-            return marlin_kernels.cutlass_scaled_mm_azp(
+            return quantization.cutlass_scaled_mm_azp(
                 a=qinput,
                 b=self.weight,
                 scale_a=input_scale,

@@ -16,11 +16,11 @@ from text_generation_server.utils.weights import (
 from text_generation_server.utils.log import log_once

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None

 try:
     # TODO: needs to be ported over to MoE and used on CUDA.
@@ -33,9 +33,9 @@ quant_dtype: torch.dtype = (
     torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
 )

-if SYSTEM == "cuda" and marlin_kernels is not None:
+if SYSTEM == "cuda" and quantization is not None:
     major, minor = torch.cuda.get_device_capability()
-    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
+    CUTLASS_FP8_AVAILABLE = quantization.cutlass_scaled_mm_supports_fp8(
         major * 10 + minor
     )
 else:
@@ -147,9 +147,9 @@ def fp8_quantize(
     argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
     be used without modification).
     """
-    if marlin_kernels is not None:
+    if quantization is not None:
         shape = weight.shape
-        qweight, scale = marlin_kernels.scaled_fp8_quant(
+        qweight, scale = quantization.scaled_fp8_quant(
             weight.reshape(-1, shape[-1]),
             scale=scale,
             scale_ub=scale_upper_bound,
@@ -530,7 +530,7 @@ class Fp8Linear(torch.nn.Module):
             qinput, scale = fp8_quantize(
                 input, scale_upper_bound=self.scale_upper_bound, scalar=False
             )
-            return marlin_kernels.cutlass_scaled_mm(
+            return quantization.cutlass_scaled_mm(
                 qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
             )


@@ -12,11 +12,11 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None


 MARLIN_TILE_SIZE = 16
@@ -36,7 +36,7 @@ class GPTQMarlinFP8Linear(nn.Module):
         super().__init__()

         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None

         scales = scales.unsqueeze(0)
         if scales.shape[1] == 1:
@@ -73,10 +73,10 @@ class GPTQMarlinFP8Linear(nn.Module):
         return cls(qweight=weight, scales=scale.to(dtype), bias=bias)

     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None

         A_flat = A.view(-1, A.shape[-1])
-        C = marlin_kernels.fp8_marlin_gemm(
+        C = quantization.fp8_marlin_gemm(
             A_flat,
             self.qweight,
             self.scales,
@@ -138,7 +138,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
     qweight = pack_fp8_as_int32(weight.t())

     perm = torch.empty(0, dtype=torch.int, device=qweight.device)
-    repacked = marlin_kernels.gptq_marlin_repack(
+    repacked = quantization.gptq_marlin_repack(
         qweight, perm, in_features, out_features, 8
     )


@@ -17,11 +17,11 @@ from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None


 try:
@@ -41,7 +41,7 @@ def can_use_gptq_marlin(
 ) -> bool:
     return (
         SYSTEM == "cuda"
-        and marlin_kernels is not None
+        and quantization is not None
         and has_sm_8_0
         and quantize in {"awq", "gptq"}
         and quant_method in {"awq", "gptq"}
@@ -291,7 +291,7 @@ def repack_gptq_for_marlin(
 ) -> GPTQMarlinWeight:
     """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
     _check_marlin_kernels()
-    assert marlin_kernels is not None
+    assert quantization is not None

     if bits not in GPTQ_MARLIN_BITS:
         supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
@@ -334,7 +334,7 @@ def repack_gptq_for_marlin(
         g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

     if quant_method == "awq":
-        repacked = marlin_kernels.awq_marlin_repack(
+        repacked = quantization.awq_marlin_repack(
             qweight, in_features, out_features, bits
         )
         if qzeros is not None:
@@ -346,7 +346,7 @@ def repack_gptq_for_marlin(
             )

     else:
-        repacked = marlin_kernels.gptq_marlin_repack(
+        repacked = quantization.gptq_marlin_repack(
             qweight, perm, in_features, out_features, bits
         )

@@ -383,7 +383,7 @@ class GPTQMarlinLinear(nn.Module):
         super().__init__()

         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None

         in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
         out_features = weight.scales.shape[1]
@@ -394,14 +394,14 @@ class GPTQMarlinLinear(nn.Module):

         if weight.qzeros.numel() > 0:
             if weight.bits == 4:
-                self.quant_type = marlin_kernels.scalar_types.uint4
+                self.quant_type = quantization.scalar_types.uint4
             else:
-                self.quant_type = marlin_kernels.scalar_types.uint8
+                self.quant_type = quantization.scalar_types.uint8
         else:
             if weight.bits == 4:
-                self.quant_type = marlin_kernels.scalar_types.uint4b8
+                self.quant_type = quantization.scalar_types.uint4b8
             else:
-                self.quant_type = marlin_kernels.scalar_types.uint8b128
+                self.quant_type = quantization.scalar_types.uint8b128

         self.is_full_k = weight.is_full_k

@@ -420,10 +420,10 @@ class GPTQMarlinLinear(nn.Module):
         )

     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None

         A_flat = A.view(-1, A.shape[-1])
-        C = marlin_kernels.gptq_marlin_gemm(
+        C = quantization.gptq_marlin_gemm(
             A_flat,
             self.qweight,
             self.scales,

@@ -10,11 +10,11 @@ from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None


 class MarlinWeightsLoader(WeightsLoader):
@@ -192,7 +192,7 @@ class MarlinLinear(nn.Module):
         super().__init__()

         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None

         in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
         out_features = weight.s.shape[1]
@@ -221,9 +221,9 @@ class MarlinLinear(nn.Module):
         )

     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None

-        C = marlin_kernels.marlin_gemm(
+        C = quantization.marlin_gemm(
             A.view(-1, A.shape[-1]),
             self.B,
             self.s,
@@ -282,7 +282,7 @@ class GPTQMarlin24Linear(nn.Module):
         super().__init__()

         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None

         if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
             supported_bits = ", ".join(
@@ -309,9 +309,9 @@ class GPTQMarlin24Linear(nn.Module):
             )

         if weight.bits == 4:
-            self.quant_type = marlin_kernels.scalar_types.uint4b8
+            self.quant_type = quantization.scalar_types.uint4b8
         else:
-            self.quant_type = marlin_kernels.scalar_types.uint8b128
+            self.quant_type = quantization.scalar_types.uint8b128
         weights_per_int32 = 32 // weight.bits

         assert (
@@ -344,9 +344,9 @@ class GPTQMarlin24Linear(nn.Module):
         )

     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None

-        C = marlin_kernels.gptq_marlin_24_gemm(
+        C = quantization.gptq_marlin_24_gemm(
             A.view(-1, A.shape[-1]),
             self.weight_packed,
             self.meta,

@@ -7,11 +7,11 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel

 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None

 try:
     major, _minor = torch.cuda.get_device_capability()
@@ -26,7 +26,7 @@ def _check_marlin_kernels():
             "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
         )

-    if marlin_kernels is None:
+    if quantization is None:
         raise NotImplementedError(
             "marlin is not installed, install it with: pip install server/marlin"
         )
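Because the external kernel package itself is unchanged, a quick way to sanity-check the rename is to load the Hub kernel directly and confirm that the entry points exercised in the hunks above resolve. A sketch, assuming the Hugging Face `kernels` package with its `get_kernel` API and network access to the Hub; it is not part of this commit.

# Hypothetical smoke test: fetch the kernel that the renamed `quantization`
# binding points at and check that the functions used in the diff are present.
from kernels import get_kernel

quantization = get_kernel("kernels-community/quantization")

for name in (
    "scaled_int8_quant",
    "scaled_fp8_quant",
    "cutlass_scaled_mm",
    "cutlass_scaled_mm_azp",
    "cutlass_scaled_mm_supports_fp8",
    "fp8_marlin_gemm",
    "gptq_marlin_gemm",
    "gptq_marlin_24_gemm",
    "marlin_gemm",
    "gptq_marlin_repack",
    "awq_marlin_repack",
    "scalar_types",
):
    assert hasattr(quantization, name), f"missing kernel entry point: {name}"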