From 8ad383c7cb64dd1f8f0ebbb9376d4e5f4743e301 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Wed, 5 Feb 2025 15:10:41 +0000
Subject: [PATCH] marlin-kernels -> quantization

---
 .../layers/compressed_tensors/w8a8_int.py    | 16 ++++++------
 server/text_generation_server/layers/fp8.py  | 14 +++++-----
 .../layers/marlin/fp8.py                     | 12 ++++-----
 .../layers/marlin/gptq.py                    | 26 +++++++++----------
 .../layers/marlin/marlin.py                  | 20 +++++++-------
 .../layers/marlin/util.py                    |  6 ++---
 6 files changed, 47 insertions(+), 47 deletions(-)

diff --git a/server/text_generation_server/layers/compressed_tensors/w8a8_int.py b/server/text_generation_server/layers/compressed_tensors/w8a8_int.py
index e9e3e975..b66057ec 100644
--- a/server/text_generation_server/layers/compressed_tensors/w8a8_int.py
+++ b/server/text_generation_server/layers/compressed_tensors/w8a8_int.py
@@ -12,11 +12,11 @@ from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None
 
 
 class W8A8IntLoader(WeightsLoader):
@@ -163,8 +163,8 @@ class Int8Weight(Weight):
 
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
-            assert marlin_kernels is not None
-            qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
+            assert quantization is not None
+            qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)
             return W8A8IntLinear(
                 bias=bias,
                 input_symmetric=self.input_symmetric,
@@ -208,9 +208,9 @@ class W8A8IntLinear(torch.nn.Module):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None
 
-        qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
+        qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(
             input=input,
             scale=None,
             azp=None,
@@ -218,7 +218,7 @@ class W8A8IntLinear(torch.nn.Module):
         )
 
         if self.input_symmetric:
-            return marlin_kernels.cutlass_scaled_mm(
+            return quantization.cutlass_scaled_mm(
                 a=qinput,
                 b=self.weight,
                 scale_a=input_scale,
@@ -233,7 +233,7 @@ class W8A8IntLinear(torch.nn.Module):
             and (self.input_symmetric or input_zero_point is not None)
         )
 
-        return marlin_kernels.cutlass_scaled_mm_azp(
+        return quantization.cutlass_scaled_mm_azp(
             a=qinput,
             b=self.weight,
             scale_a=input_scale,
diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index d412c5a4..fe138a4a 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -16,11 +16,11 @@ from text_generation_server.utils.weights import (
 from text_generation_server.utils.log import log_once
 
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None
 
 try:
     # TODO: needs to be ported over to MoE and used on CUDA.
@@ -33,9 +33,9 @@ quant_dtype: torch.dtype = (
     torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
 )
 
-if SYSTEM == "cuda" and marlin_kernels is not None:
+if SYSTEM == "cuda" and quantization is not None:
     major, minor = torch.cuda.get_device_capability()
-    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
+    CUTLASS_FP8_AVAILABLE = quantization.cutlass_scaled_mm_supports_fp8(
         major * 10 + minor
     )
 else:
@@ -147,9 +147,9 @@ def fp8_quantize(
     argument, it must also be a reciprocal (so that scales from an FP8
     checkpoint can be used without modification).
     """
-    if marlin_kernels is not None:
+    if quantization is not None:
         shape = weight.shape
-        qweight, scale = marlin_kernels.scaled_fp8_quant(
+        qweight, scale = quantization.scaled_fp8_quant(
             weight.reshape(-1, shape[-1]),
             scale=scale,
             scale_ub=scale_upper_bound,
@@ -530,7 +530,7 @@ class Fp8Linear(torch.nn.Module):
             qinput, scale = fp8_quantize(
                 input, scale_upper_bound=self.scale_upper_bound, scalar=False
             )
-            return marlin_kernels.cutlass_scaled_mm(
+            return quantization.cutlass_scaled_mm(
                 qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
             )
 
diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py
index 48f5289f..10751a05 100644
--- a/server/text_generation_server/layers/marlin/fp8.py
+++ b/server/text_generation_server/layers/marlin/fp8.py
@@ -12,11 +12,11 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None
 
 
 MARLIN_TILE_SIZE = 16
@@ -36,7 +36,7 @@ class GPTQMarlinFP8Linear(nn.Module):
         super().__init__()
 
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None
 
         scales = scales.unsqueeze(0)
         if scales.shape[1] == 1:
@@ -73,10 +73,10 @@ class GPTQMarlinFP8Linear(nn.Module):
         return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
 
     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None
 
         A_flat = A.view(-1, A.shape[-1])
-        C = marlin_kernels.fp8_marlin_gemm(
+        C = quantization.fp8_marlin_gemm(
             A_flat,
             self.qweight,
             self.scales,
@@ -138,7 +138,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
     qweight = pack_fp8_as_int32(weight.t())
 
     perm = torch.empty(0, dtype=torch.int, device=qweight.device)
-    repacked = marlin_kernels.gptq_marlin_repack(
+    repacked = quantization.gptq_marlin_repack(
         qweight, perm, in_features, out_features, 8
     )
 
diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py
index 066724e2..e85c8333 100644
--- a/server/text_generation_server/layers/marlin/gptq.py
+++ b/server/text_generation_server/layers/marlin/gptq.py
@@ -17,11 +17,11 @@ from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
    )
 else:
-    marlin_kernels = None
+    quantization = None
 
 
 try:
@@ -41,7 +41,7 @@ def can_use_gptq_marlin(
 ) -> bool:
     return (
         SYSTEM == "cuda"
-        and marlin_kernels is not None
+        and quantization is not None
         and has_sm_8_0
         and quantize in {"awq", "gptq"}
         and quant_method in {"awq", "gptq"}
@@ -291,7 +291,7 @@ def repack_gptq_for_marlin(
 ) -> GPTQMarlinWeight:
     """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
     _check_marlin_kernels()
-    assert marlin_kernels is not None
+    assert quantization is not None
 
     if bits not in GPTQ_MARLIN_BITS:
         supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
@@ -334,7 +334,7 @@ def repack_gptq_for_marlin(
         g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
 
     if quant_method == "awq":
-        repacked = marlin_kernels.awq_marlin_repack(
+        repacked = quantization.awq_marlin_repack(
             qweight, in_features, out_features, bits
         )
         if qzeros is not None:
@@ -346,7 +346,7 @@ def repack_gptq_for_marlin(
             )
 
     else:
-        repacked = marlin_kernels.gptq_marlin_repack(
+        repacked = quantization.gptq_marlin_repack(
             qweight, perm, in_features, out_features, bits
         )
 
@@ -383,7 +383,7 @@ class GPTQMarlinLinear(nn.Module):
         super().__init__()
 
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None
 
         in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
         out_features = weight.scales.shape[1]
@@ -394,14 +394,14 @@ class GPTQMarlinLinear(nn.Module):
 
         if weight.qzeros.numel() > 0:
             if weight.bits == 4:
-                self.quant_type = marlin_kernels.scalar_types.uint4
+                self.quant_type = quantization.scalar_types.uint4
             else:
-                self.quant_type = marlin_kernels.scalar_types.uint8
+                self.quant_type = quantization.scalar_types.uint8
         else:
             if weight.bits == 4:
-                self.quant_type = marlin_kernels.scalar_types.uint4b8
+                self.quant_type = quantization.scalar_types.uint4b8
             else:
-                self.quant_type = marlin_kernels.scalar_types.uint8b128
+                self.quant_type = quantization.scalar_types.uint8b128
 
         self.is_full_k = weight.is_full_k
 
@@ -420,10 +420,10 @@ class GPTQMarlinLinear(nn.Module):
         )
 
     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None
 
         A_flat = A.view(-1, A.shape[-1])
-        C = marlin_kernels.gptq_marlin_gemm(
+        C = quantization.gptq_marlin_gemm(
             A_flat,
             self.qweight,
             self.scales,
diff --git a/server/text_generation_server/layers/marlin/marlin.py b/server/text_generation_server/layers/marlin/marlin.py
index f01f6af2..48aedc72 100644
--- a/server/text_generation_server/layers/marlin/marlin.py
+++ b/server/text_generation_server/layers/marlin/marlin.py
@@ -10,11 +10,11 @@ from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None
 
 
 class MarlinWeightsLoader(WeightsLoader):
@@ -192,7 +192,7 @@ class MarlinLinear(nn.Module):
         super().__init__()
 
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None
 
         in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
         out_features = weight.s.shape[1]
@@ -221,9 +221,9 @@ class MarlinLinear(nn.Module):
         )
 
     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None
 
-        C = marlin_kernels.marlin_gemm(
+        C = quantization.marlin_gemm(
             A.view(-1, A.shape[-1]),
             self.B,
             self.s,
@@ -282,7 +282,7 @@ class GPTQMarlin24Linear(nn.Module):
         super().__init__()
 
         _check_marlin_kernels()
-        assert marlin_kernels is not None
+        assert quantization is not None
 
         if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
             supported_bits = ", ".join(
@@ -309,9 +309,9 @@ class GPTQMarlin24Linear(nn.Module):
         )
 
         if weight.bits == 4:
-            self.quant_type = marlin_kernels.scalar_types.uint4b8
+            self.quant_type = quantization.scalar_types.uint4b8
         else:
-            self.quant_type = marlin_kernels.scalar_types.uint8b128
+            self.quant_type = quantization.scalar_types.uint8b128
 
         weights_per_int32 = 32 // weight.bits
         assert (
@@ -344,9 +344,9 @@ class GPTQMarlin24Linear(nn.Module):
         )
 
     def forward(self, A: torch.Tensor) -> torch.Tensor:
-        assert marlin_kernels is not None
+        assert quantization is not None
 
-        C = marlin_kernels.gptq_marlin_24_gemm(
+        C = quantization.gptq_marlin_24_gemm(
             A.view(-1, A.shape[-1]),
             self.weight_packed,
             self.meta,
diff --git a/server/text_generation_server/layers/marlin/util.py b/server/text_generation_server/layers/marlin/util.py
index e73b5397..0c5d715f 100644
--- a/server/text_generation_server/layers/marlin/util.py
+++ b/server/text_generation_server/layers/marlin/util.py
@@ -7,11 +7,11 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 
 if SYSTEM == "cuda":
-    marlin_kernels = load_kernel(
+    quantization = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
 else:
-    marlin_kernels = None
+    quantization = None
 
 try:
     major, _minor = torch.cuda.get_device_capability()
@@ -26,7 +26,7 @@ def _check_marlin_kernels():
             "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
         )
 
-    if marlin_kernels is None:
+    if quantization is None:
         raise NotImplementedError(
             "marlin is not installed, install it with: pip install server/marlin"
        )
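
Note (not part of the patch): every hunk above follows the same pattern — load the Hub-provided quantization kernels once at import time via load_kernel, then guard each call site with an assert before invoking a kernel. The sketch below illustrates that pattern; it mirrors the Fp8Linear path in server/text_generation_server/layers/fp8.py, but the helper name fp8_linear and the 2D input/weight shapes are assumptions for illustration only, not code from the repository.

    # Sketch only: dynamic FP8 matmul using the `quantization` kernel module,
    # following the load_kernel / assert pattern used throughout the patch.
    import torch

    from text_generation_server.utils.import_utils import SYSTEM
    from text_generation_server.utils.kernels import load_kernel

    if SYSTEM == "cuda":
        quantization = load_kernel(
            module="quantization", repo_id="kernels-community/quantization"
        )
    else:
        quantization = None


    def fp8_linear(input: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
        # Same guard as the `assert quantization is not None` checks in the patch.
        assert quantization is not None

        # Dynamic quantization (scale=None) of activations and weights to FP8,
        # returning the quantized tensor and its scale, as in fp8_quantize.
        qinput, input_scale = quantization.scaled_fp8_quant(input, scale=None)
        qweight, weight_scale = quantization.scaled_fp8_quant(weight, scale=None)

        # Same call shape as Fp8Linear.forward in the patch:
        # cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
        return quantization.cutlass_scaled_mm(
            qinput, qweight.t(), input_scale, weight_scale, input.dtype, bias
        )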