From 9914ffe1f195e8dba790d91db6ba931fa4bbf329 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 21 Mar 2025 18:28:58 -0700 Subject: [PATCH] remove unused quantization code and enable awq/gptq int4 Signed-off-by: Wang, Yi A --- .../gaudi/server/tests/utils/test_weights.py | 137 ------ .../server/text_generation_server/cli.py | 6 - .../layers/awq/quantize/hpu.py | 117 ++++- .../layers/compressed_tensors/__init__.py | 3 - .../layers/compressed_tensors/loader.py | 196 -------- .../layers/compressed_tensors/w8a8_int.py | 239 --------- .../layers/compressed_tensors/w8an_fp.py | 168 ------- .../layers/compressed_tensors/wna16_int.py | 188 ------- .../layers/compressed_tensors/wna16_int_24.py | 101 ---- .../layers/gptq/__init__.py | 14 +- .../layers/gptq/{ipex.py => hpu.py} | 306 +++++++----- .../layers/marlin/__init__.py | 15 - .../layers/marlin/fp8.py | 141 ------ .../layers/marlin/gptq.py | 465 ------------------ .../layers/marlin/marlin.py | 359 -------------- .../layers/marlin/util.py | 137 ------ .../custom_modeling/flash_gemma_modeling.py | 2 + .../custom_modeling/flash_gpt2_modeling.py | 4 - .../flash_pali_gemma_modeling.py | 1 + .../custom_modeling/flash_phi_modeling.py | 2 +- .../flash_santacoder_modeling.py | 4 - .../models/seq2seq_lm.py | 2 +- .../utils/quantization.py | 57 +-- 23 files changed, 291 insertions(+), 2373 deletions(-) delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py rename backends/gaudi/server/text_generation_server/layers/gptq/{ipex.py => hpu.py} (57%) delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/__init__.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/fp8.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/gptq.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/marlin.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/util.py diff --git a/backends/gaudi/server/tests/utils/test_weights.py b/backends/gaudi/server/tests/utils/test_weights.py index 556fcea1..c301e50e 100644 --- a/backends/gaudi/server/tests/utils/test_weights.py +++ b/backends/gaudi/server/tests/utils/test_weights.py @@ -7,10 +7,6 @@ from text_generation_server.utils.weights import ( ) from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader -from text_generation_server.layers.marlin.marlin import ( - MarlinWeight, - MarlinWeightsLoader, -) from types import SimpleNamespace from typing import List, Optional, Dict, Union from pathlib import Path @@ -40,11 +36,6 @@ def gptq_weights_loader_awq(): ) -@pytest.fixture -def marlin_weights_loader(): - return MarlinWeightsLoader(bits=4, is_marlin_24=False) - - dummy_file_system = { "test_weights": { "layer.0.weight": torch.tensor( @@ -125,10 +116,6 @@ dummy_file_system = { "gptq_bits": torch.tensor([8], dtype=torch.float32), "gptq_groupsize": 
torch.tensor([2], dtype=torch.float32), }, - "test_get_weights_col_marlin": { - "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - }, "test_get_weights_row_gptq": { "weight.qweight": torch.tensor( [ @@ -273,18 +260,6 @@ dummy_file_system = { "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, - "test_get_weights_row_marlin": { - "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), - }, - "test_get_multi_weights_col_marlin": { - "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), - }, - "test_get_weights_col_packed_marlin": { - "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), - }, } @@ -718,33 +693,6 @@ def test_get_weights_col_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_marlin(marlin_weights_loader): - weights = MockWeights( - [ - "test_get_weights_col_marlin", - ], - device="cpu", - dtype=torch.float16, - process_group=dummy_process_group, - dummy_fs=dummy_file_system, - weights_loader=marlin_weights_loader, - ) - - prefix = "weight" - - w = weights.get_weights_col( - prefix=prefix, - ) - - expected_weight = MarlinWeight( - B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - ) - - assert torch.allclose(w.B, expected_weight.B), "B mismatch" - assert torch.allclose(w.s, expected_weight.s), "s mismatch" - - # test_get_weights_col_packed @@ -868,36 +816,6 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_packed_marlin(marlin_weights_loader): - weights = MockWeights( - [ - "test_get_weights_col_packed_marlin", - ], - device="cpu", - dtype=torch.float16, - process_group=dummy_process_group, - dummy_fs=dummy_file_system, - weights_loader=marlin_weights_loader, - ) - - prefix = "weight" - - w = weights.get_multi_weights_col( - prefixes=[prefix], - dim=0, - ) - - expected_weight = MarlinWeight( - B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - ) - - print(expected_weight) - - assert torch.allclose(w.B, expected_weight.B), "B mismatch" - assert torch.allclose(w.s, expected_weight.s), "s mismatch" - - # test_get_multi_weights_col @@ -1004,34 +922,6 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_col_marlin(marlin_weights_loader): - weights = MockWeights( - [ - "test_get_multi_weights_col_marlin", - ], - device="cpu", - dtype=torch.float16, - process_group=dummy_process_group, - dummy_fs=dummy_file_system, - weights_loader=marlin_weights_loader, - ) - - prefix = "weight" - - w = weights.get_multi_weights_col( - prefixes=[prefix], - dim=0, - ) - - expected_weight = MarlinWeight( - B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - ) - - assert torch.allclose(w.B, expected_weight.B), "B mismatch" - assert torch.allclose(w.s, expected_weight.s), "s mismatch" - - # test_get_weights_row @@ -1148,30 +1038,3 @@ def 
test_get_weights_row_gptq(gptq_weights_loader): assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" - - -def test_get_weights_row_marlin(marlin_weights_loader): - weights = MockWeights( - [ - "test_get_weights_row_marlin", - ], - device="cpu", - dtype=torch.float16, - process_group=dummy_process_group, - dummy_fs=dummy_file_system, - weights_loader=marlin_weights_loader, - ) - - prefix = "weight" - - w = weights.get_weights_row( - prefix=prefix, - ) - - expected_weight = MarlinWeight( - B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - ) - - assert torch.allclose(w.B, expected_weight.B), "B mismatch" - assert torch.allclose(w.s, expected_weight.s), "s mismatch" diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index 24d1d748..e1c0298d 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -16,15 +16,9 @@ app = typer.Typer() class Quantization(str, Enum): - bitsandbytes = "bitsandbytes" - bitsandbytes_nf4 = "bitsandbytes-nf4" - bitsandbytes_fp4 = "bitsandbytes-fp4" gptq = "gptq" awq = "awq" - eetq = "eetq" - exl2 = "exl2" fp8 = "fp8" - marlin = "marlin" class Dtype(str, Enum): diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py index 391371a5..3af0131b 100644 --- a/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py @@ -1,19 +1,93 @@ -# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py - from typing import Optional import torch import torch.nn as nn -import awq_inference_engine # with CUDA kernels + +try: + import habana_frameworks.torch.hpu # noqa: F401 + + convert_from_uint4 = torch.ops.hpu.convert_from_uint4 +except Exception as e: + hpu_import_exception = e + + def error_raiser_hpu(*args, **kwargs): + raise ValueError( + f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}" + ) + + convert_from_uint4 = error_raiser_hpu + +AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] -# class ScaledActivation(nn.Module): -# def __init__(self, module, scales): -# super().__init__() -# self.act = module -# self.scales = nn.Parameter(scales.data) -# -# def forward(self, x): -# return self.act(x) / self.scales.view(1, 1, -1).to(x.device) +def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int): + shifts = torch.arange(0, 32, bits, device=qzeros.device) + + # unpacking columnwise + iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 # smallest dtype available + ) + iweights = iweights.view(iweights.shape[0], -1) + + # unpacking columnwise + if qzeros is not None: + izeros = torch.bitwise_right_shift( + qzeros[:, :, None], shifts[None, None, :] + ).to( + torch.int8 # smallest dtype available + ) + izeros = izeros.view(izeros.shape[0], -1) + else: + izeros = qzeros + + return iweights, izeros + + +def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int): + reverse_order_tensor = torch.arange( + iweights.shape[-1], + 
dtype=torch.int32, + device=izeros.device, + ) + reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] + reverse_order_tensor = reverse_order_tensor.view(-1) + + if izeros is not None: + izeros = izeros[:, reverse_order_tensor] + iweights = iweights[:, reverse_order_tensor] + + return iweights, izeros + + +def unpack_weight_and_zeros(qweight, qzeros, bits): + # Unpack the qweight and qzeros tensors + iweight, izeros = unpack_awq(qweight, qzeros, bits) + # Reverse the order of the iweight and izeros tensors + iweight, izeros = reverse_awq_order(iweight, izeros, bits) + + # overflow checks + iweight = torch.bitwise_and(iweight, (2**bits) - 1) + izeros = torch.bitwise_and(izeros, (2**bits) - 1) + + return iweight, izeros + + +def pack_tensor(input, bits=4): + normal = input.to(torch.int32) + q = torch.zeros( + (normal.shape[0], normal.shape[1] // 32 * bits), + dtype=torch.int32, + device=input.device, + ) + i = 0 + col = 0 + while col < q.shape[1]: + for j in range(i, i + (32 // bits)): + q[:, col] |= normal[:, j] << (bits * (j - i)) + i += 32 // bits + col += 1 + q = q.to(torch.int32) + return q class WQLinear(nn.Module): @@ -38,12 +112,23 @@ class WQLinear(nn.Module): self.qzeros = qzeros self.scales = scales self.bias = bias + self._preprocessing() + + def _preprocessing(self): + device = self.qweight.device + weight, zeros = unpack_weight_and_zeros( + self.qweight.cpu(), self.qzeros.cpu(), self.w_bit + ) + self.qweight = pack_tensor(weight).to(device) + self.qzeros = pack_tensor(zeros).to(device) @torch.no_grad() def forward(self, x): out_shape = x.shape[:-1] + (self.out_features,) - out = awq_inference_engine.gemm_forward_cuda( - x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 - ) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) + x = x.reshape(-1, x.shape[-1]) + weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype) + outputs = torch.matmul(x, weights) + + outputs = outputs + self.bias if self.bias is not None else outputs + outputs = outputs.reshape(out_shape) + return outputs diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py deleted file mode 100644 index 507af706..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .loader import CompressedTensorsLoader - -__all__ = ["CompressedTensorsLoader"] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py deleted file mode 100644 index 17d0224e..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py +++ /dev/null @@ -1,196 +0,0 @@ -from typing import Any, Dict, List, Union - -from compressed_tensors import QuantizationConfig, QuantizationStatus -from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization import ( - QuantizationScheme, - QuantizationType, - find_name_or_class_matches, -) -from loguru import logger -from pydantic import ValidationError -from torch import nn - -from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader -from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader -from 
text_generation_server.layers.compressed_tensors.wna16_int_24 import ( - WNA16Int24Loader, -) -from text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader -from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import ( - DefaultWeightsLoader, - UnquantizedWeight, - Weights, - WeightsLoader, -) - -# compressed-tensors can match modules as quantization targets. However, -# they need to be objects rather than classes or class names. Since we -# need to match `Linear` targets, make an instance that can be re-used. -_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0) - - -class CompressedTensorsLoader(WeightsLoader): - """Loader for checkpoints stored in the compressed-tensors format.""" - - def __init__(self, config: Dict[str, Any]): - quantization_config_raw = config.get("quantization_config") - if quantization_config_raw is None: - # `compression_config` was renamed to `quantization_config`; support - # retained for backward compatibility. - quantization_config_raw = config.get("compression_config") - if quantization_config_raw is None: - raise ValueError( - "Checkpoint does not have compressed-tensors configuration" - ) - - try: - quantization_config = QuantizationConfig.model_validate( - quantization_config_raw - ) - except ValidationError as e: - raise ValueError("Cannot parse compressed-tensors configuration") from e - - if quantization_config.quantization_status not in ( - QuantizationStatus.COMPRESSED, - QuantizationStatus.FROZEN, - ): - raise ValueError( - f"Model quantization was not finished, status was: {quantization_config.quantization_status}" - ) - - self.ignore = ( - quantization_config.ignore if quantization_config.ignore is not None else [] - ) - self.loaders = self._get_target_loaders(quantization_config) - - for target, loader in self.loaders.items(): - log_once( - logger.info, - f"Using {loader} for compressed-tensors target '{target}'", - ) - - def get_weights(self, weights: Weights, prefix: str): - loader = self._lookup_loader(prefix) - return loader.get_weights(weights, prefix) - - def get_weights_col_packed( - self, - weights: "Weights", - prefix: str, - block_sizes: Union[int, List[int]], - ): - loader = self._lookup_loader(prefix) - return loader.get_weights_col_packed(weights, prefix, block_sizes) - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - loader = self._lookup_loader(prefixes[0]) - return loader.get_multi_weights_col(weights, prefixes, dim) - - def get_weights_row(self, weights: Weights, prefix: str): - loader = self._lookup_loader(prefix) - return loader.get_weights_row(weights, prefix) - - def _get_target_loaders( - self, quantization_config: QuantizationConfig - ) -> Dict[str, WeightsLoader]: - """ - A compressed-tensors checkpoint can use different quantizations - for different targets. This method returns a dictionary with a - loader per target. - """ - - loaders: Dict[str, WeightsLoader] = {} - - format = quantization_config.format - - for group_name, group in quantization_config.config_groups.items(): - # The group configuration can be a string, but does that ever - # happen in a serialized quantization config? - assert isinstance(group, QuantizationScheme) - - loader = self._create_loader_for_group(format, group_name, group) - - # A quantized parameter group can have multiple targets, add the - # loader for all the targets. 
- for target in group.targets: - if target in loaders: - raise ValueError( - f"Target '{target} has multiple configured loaders'" - ) - loaders[target] = loader - - return loaders - - def _create_loader_for_group( - self, format: str, group_name: str, group: QuantizationScheme - ) -> WeightsLoader: - """ - Find and create a loader for the group with the given quantization - scheme. - """ - # NOTE: we ignore group.output_activations because we don't support - # output quantization yet. - - input_activations = group.input_activations - weights = group.weights - if ( - format - in { - CompressionFormat.float_quantized.value, - CompressionFormat.naive_quantized.value, - } - and weights is not None - and weights.type == QuantizationType.FLOAT - and weights.num_bits == 8 - ): - # FP W8A8 or W8A16. - return W8ANFpLoader(input_activations=input_activations, weights=weights) - elif ( - format == CompressionFormat.pack_quantized.value - and weights is not None - and weights.type == QuantizationType.INT - and weights.num_bits in (4, 8) - ): - # INT W4A16 or W8A16 (GPTQ/AWQ-like). - return WNA16IntLoader(weights) - elif ( - format == CompressionFormat.marlin_24.value - and weights is not None - and weights.type == QuantizationType.INT - and weights.num_bits in (4, 8) - ): - return WNA16Int24Loader(weights) - elif ( - format - in { - CompressionFormat.int_quantized.value, - CompressionFormat.naive_quantized.value, - } - and weights is not None - and weights.type == QuantizationType.INT - and weights.num_bits == 8 - ): - return W8A8IntLoader(input_args=input_activations, weight_args=weights) - else: - raise ValueError( - f"Group '{group_name}' has unsupported compressed-tensors configurtion" - ) - - def _lookup_loader(self, prefix: str) -> WeightsLoader: - """ - Look up the loader to use for a given parameter name (prefix). - """ - - if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0: - return DefaultWeightsLoader(UnquantizedWeight) - - # We currently only handle linear layers, so unconditionally pass - # a `Linear` instance. - targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys()) - if len(targets) == 0: - raise ValueError( - f"Cannot find compressed-tensors target for prefix: {prefix}" - ) - return self.loaders[targets[0]] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py deleted file mode 100644 index fff0c765..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py +++ /dev/null @@ -1,239 +0,0 @@ -from typing import List, Optional, Union, TypeVar -from dataclasses import dataclass - -from loguru import logger -import torch -from compressed_tensors.quantization import QuantizationArgs, QuantizationType - -from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale -from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader - - -quantization = None - - -class W8A8IntLoader(WeightsLoader): - """ - Loader for w8a8 integer compressed-tensors parameters. 
- """ - - def __init__( - self, - *, - input_args: Optional[QuantizationArgs], - weight_args: QuantizationArgs, - ): - if weight_args.type != QuantizationType.INT and weight_args.num_bits != 8: - raise ValueError( - f"{type(self).__name__} only supports w8a8 int checkpoints" - ) - - if not weight_args.symmetric: - raise ValueError("Checkpoints with asymmetric weights are not supported") - - self.load_weight_scale = not weight_args.dynamic - - if input_args is not None: - self.input_symmetric = input_args.symmetric - - if not input_args.dynamic: - log_once( - logger.warning, - "Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).", - ) - else: - self.input_symmetric = True - - def __str__(self) -> str: - def scale_to_str(scale): - return "static" if scale else "dynamic" - - def symmetric_to_str(symmetric): - return "symmetric" if symmetric else "asymmetric" - - return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_str(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric))" - - def get_weights(self, weights: "Weights", prefix: str): - w = weights.get_tensor(f"{prefix}.weight", to_dtype=False) - - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor( - f"{prefix}.weight_scale", to_dtype=False - ).reshape(-1) - - return Int8Weight( - input_symmetric=self.input_symmetric, - weight=w, - weight_scale=weight_scale, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - w = weights.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes, to_dtype=False - ) - - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - if weight_scale.numel() > 1: - weight_scale = weights.get_packed_sharded( - f"{prefix}.weight_scale", - dim=0, - block_sizes=block_sizes, - to_dtype=False, - ) - weight_scale = weight_scale.reshape(-1) - - return Int8Weight( - input_symmetric=self.input_symmetric, - weight=w, - weight_scale=weight_scale, - ) - - def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): - w = [ - weights.get_sharded(f"{p}.weight", dim=0, to_dtype=False) for p in prefixes - ] - shapes = [x.shape for x in w] - - w = torch.cat(w, dim=dim) - - weight_scale = None - if self.load_weight_scale: - weight_scale = [ - _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) - for p, shape in zip(prefixes, shapes) - ] - weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1) - - return Int8Weight( - input_symmetric=self.input_symmetric, - weight=w, - weight_scale=weight_scale, - ) - - def get_weights_row(self, weights: "Weights", prefix: str): - w = weights.get_sharded(f"{prefix}.weight", dim=1, to_dtype=False) - - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor( - f"{prefix}.weight_scale", to_dtype=False - ).reshape(-1) - - return Int8Weight( - input_symmetric=self.input_symmetric, - weight=w, - weight_scale=weight_scale, - ) - - -OtherT = TypeVar("OtherT") - - -def _get_tensor_or_else( - weights: Weights, prefix: str, other: OtherT -) -> Union[torch.Tensor, OtherT]: - # Even if a checkpoint uses e.g. 
zero-points, they can be elided: - # https://github.com/neuralmagic/compressed-tensors/blob/db6ccb25b265e8370813ecab5e95714a6728b5a6/src/compressed_tensors/compressors/quantized_compressors/base.py#L105 - if weights.has_tensor(prefix): - return weights.get_tensor(prefix, to_dtype=False) - else: - return other - - -@dataclass -class Int8Weight(Weight): - input_symmetric: bool - weight: torch.Tensor - weight_scale: Optional[torch.Tensor] - - def get_linear(self, bias: torch.Tensor): - if self.weight_scale is None: - assert quantization is not None - qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight) - return W8A8IntLinear( - bias=bias, - input_symmetric=self.input_symmetric, - weight=qweight, - weight_scale=weight_scale, - ) - else: - return W8A8IntLinear( - bias=bias, - input_symmetric=self.input_symmetric, - weight=self.weight, - weight_scale=self.weight_scale, - ) - - -class W8A8IntLinear(torch.nn.Module): - def __init__( - self, - *, - bias: Optional[torch.Tensor], - input_symmetric: bool, - weight: torch.Tensor, - weight_scale: torch.Tensor, - ): - super().__init__() - - weight_scale = weight_scale.to(torch.float32) - - self.bias = bias - self.input_symmetric = input_symmetric - # cutlass kernels require transposed weights. - self.weight = weight.t() - self.weight_scale = weight_scale - - if input_symmetric: - self.zero_point_adj = None - else: - # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md#scaledepilogueazp - self.zero_point_adj = self.weight.sum( - dim=0, keepdim=True, dtype=torch.int32 - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - assert quantization is not None - - qinput, input_scale, input_zero_point = quantization.scaled_int8_quant( - input=input, - scale=None, - azp=None, - symmetric=self.input_symmetric, - ) - - if self.input_symmetric: - return quantization.cutlass_scaled_mm( - a=qinput, - b=self.weight, - scale_a=input_scale, - scale_b=self.weight_scale, - out_dtype=input.dtype, - bias=self.bias, - ) - else: - assert ( - self.zero_point_adj is not None - and input_scale is not None - and (self.input_symmetric or input_zero_point is not None) - ) - - return quantization.cutlass_scaled_mm_azp( - a=qinput, - b=self.weight, - scale_a=input_scale, - scale_b=self.weight_scale, - out_dtype=input.dtype, - azp_adj=self.zero_point_adj, - azp=input_zero_point, - bias=self.bias, - ) diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py deleted file mode 100644 index ed63806e..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py +++ /dev/null @@ -1,168 +0,0 @@ -from typing import List, Optional, Union - -import torch -from compressed_tensors.quantization import QuantizationArgs, QuantizationType - -from text_generation_server.layers.fp8 import ( - Fp8Weight, - _load_scalar_or_matrix_scale, -) -from text_generation_server.utils.weights import Weights, WeightsLoader - - -class W8ANFpLoader(WeightsLoader): - """ - Loader for W8A8/W8A16 FP compressed-tensors parameters. - """ - - def __init__( - self, - *, - input_activations: Optional[QuantizationArgs], - weights: QuantizationArgs, - ): - assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8 - - # We ignore the `strategy` option which sets the scales to be - # per-tensor, per-channel or per-token. 
What scales are supported - # is dependent on the kernels used (e.g. cutlass can do tokenwise, - # Torch cannot, and FP8-Marlin does not quantize inputs at all). - # So, instead we try to use the best-possible configuration. - - self.load_weight_scale = not weights.dynamic - self.load_input_scale = ( - input_activations is not None and not input_activations.dynamic - ) - self.force_w8a16 = ( - input_activations is not None and input_activations.num_bits == 16 - ) - - def __str__(self) -> str: - def scale_to_str(scale): - return "static" if scale else "dynamic" - - quantization_type = f"W8A{16 if self.force_w8a16 else 8}" - - return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})" - - def get_weights(self, weights: "Weights", prefix: str): - w = weights.get_tensor(f"{prefix}.weight") - - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - - input_scale = None - if self.load_input_scale: - input_scale = weights.get_tensor( - f"{prefix}.input_scale", to_dtype=False - ).reshape(-1) - - return Fp8Weight( - weight=w, - weight_scale=weight_scale, - input_scale=input_scale, - dtype=weights.dtype, - force_w8a16=self.force_w8a16, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - w = weights.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes - ) - - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - if weight_scale.numel() > 1: - weight_scale = weights.get_packed_sharded( - f"{prefix}.weight_scale", - dim=0, - block_sizes=block_sizes, - to_dtype=False, - ) - - input_scale = None - if self.load_input_scale: - input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) - if input_scale.numel() > 1: - input_scale = weights.get_packed_sharded( - f"{prefix}.input_scale", - dim=0, - block_sizes=block_sizes, - to_dtype=False, - ) - input_scale = input_scale.reshape(-1).max() - - return Fp8Weight( - weight=w, - weight_scale=weight_scale, - input_scale=input_scale, - dtype=weights.dtype, - force_w8a16=self.force_w8a16, - ) - - def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): - # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet - w = [ - weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes - ] - shapes = [x.shape for x in w] - - # Concat then send to the device - w = torch.cat(w, dim=dim).to(weights.device) - - weight_scale = None - if self.load_weight_scale: - weight_scale = [ - _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) - for p, shape in zip(prefixes, shapes) - ] - weight_scale = torch.cat(weight_scale, dim=0).reshape(-1) - - input_scale = None - if self.load_input_scale: - input_scale = [ - _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) - for p, shape in zip(prefixes, shapes) - if weights.has_tensor(f"{p}.input_scale") - ] - assert len(input_scale) == 0 or len(input_scale) == len(prefixes) - input_scale = ( - torch.cat(input_scale, dim=0).reshape(-1).max() - if len(input_scale) != 0 - else None - ) - - return Fp8Weight( - weight=w, - weight_scale=weight_scale, - input_scale=input_scale, - dtype=weights.dtype, - force_w8a16=self.force_w8a16, - ) - - def get_weights_row(self, weights: "Weights", prefix: str): - w = 
weights.get_sharded(f"{prefix}.weight", dim=1) - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - - input_scale = None - if self.load_input_scale: - input_scale = weights.get_tensor( - f"{prefix}.input_scale", to_dtype=False - ).reshape(-1) - - return Fp8Weight( - weight=w, - weight_scale=weight_scale, - input_scale=input_scale, - dtype=weights.dtype, - force_w8a16=self.force_w8a16, - ) diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py deleted file mode 100644 index bb69c6b5..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py +++ /dev/null @@ -1,188 +0,0 @@ -from typing import List, Union - -import torch -from compressed_tensors.quantization import ActivationOrdering, QuantizationArgs -from loguru import logger - -from text_generation_server.layers.marlin.gptq import repack_gptq_for_marlin -from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weights, WeightsLoader - - -class WNA16IntLoader(WeightsLoader): - """ - Loader for W4A16/W8A16 INT compressed-tensors parameters. - """ - - def __init__(self, weights: QuantizationArgs): - self.weights = weights - self.desc_act = self.weights.actorder == ActivationOrdering.GROUP - self.groupsize = ( - -1 if self.weights.group_size is None else self.weights.group_size - ) - - def __str__(self) -> str: - quantization_type = f"W{self.weights.num_bits}A16" - - return f"{self.__class__.__name__} ({quantization_type})" - - def get_weights(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - weight_packed = weights.get_tensor(f"{prefix}.weight_packed").t() - except RuntimeError: - raise RuntimeError( - f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized" - ) - - zero_point = None - if not self.weights.symmetric: - zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t() - - g_idx = None - if self.desc_act: - g_idx = weights.get_tensor(f"{prefix}.weight_g_idx") - - scales = weights.get_tensor(f"{prefix}.weight.scales").t() - - return repack_gptq_for_marlin( - qweight=weight_packed.contiguous(), - scales=scales, - qzeros=zero_point, - g_idx=g_idx, - bits=self.weights.num_bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method="compressed-tensors", - sym=self.weights.symmetric, - sharded_infeatures=False, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - try: - weight_packed = weights.get_packed_sharded( - f"{prefix}.weight_packed", dim=0, block_sizes=block_sizes - ).t() - except RuntimeError: - raise RuntimeError( - f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized" - ) - scales = weights.get_packed_sharded( - f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes - ).t() - scales = scales.to(dtype=weights.dtype) - - zero_point = None - if not self.weights.symmetric: - zero_point = weights.get_packed_sharded( - f"{prefix}.qzeros", dim=0, block_sizes=block_sizes - ).t() - - g_idx = None - if self.desc_act: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - - return repack_gptq_for_marlin( - qweight=weight_packed.contiguous(), - scales=scales, - qzeros=zero_point, - g_idx=g_idx, - bits=self.weights.num_bits, - 
desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method="compressed-tensors", - sym=self.weights.symmetric, - sharded_infeatures=False, - ) - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - try: - weight_packed = torch.cat( - [ - weights.get_sharded(f"{p}.weight_packed", dim=0).t() - for p in prefixes - ], - dim=1, - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized" - ) - - scales = torch.cat( - [weights.get_sharded(f"{p}.weight_scale", dim=0).t() for p in prefixes], - dim=1, - ) - - zero_point = None - if not self.weights.symmetric: - zero_point = torch.cat( - [weights.get_sharded(f"{p}.qzeros", dim=0).t() for p in prefixes], dim=1 - ).t() - - g_idx = None - if self.desc_act: - w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=weight_packed.contiguous(), - scales=scales, - qzeros=zero_point, - g_idx=g_idx, - bits=self.weights.num_bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method="compressed-tensors", - sym=self.weights.symmetric, - sharded_infeatures=False, - ) - - def get_weights_row(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=1).t() - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." - ) - - zero_point = None - if not self.weights.symmetric: - if self.desc_act or self.groupsize == -1: - zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t() - else: - zero_point = weights.get_sharded( - f"{prefix}.weight_zero_point", dim=1 - ).t() - - g_idx = None - if self.desc_act: - g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) - - if self.desc_act or self.groupsize == -1: - scales = weights.get_tensor(f"{prefix}.weight_scale").t() - else: - scales = weights.get_sharded(f"{prefix}.weight_scale", dim=1).t() - - sharded_in_features = weights.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=weight_packed.contiguous(), - scales=scales, - qzeros=zero_point, - g_idx=g_idx, - bits=self.weights.num_bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method="compressed-tensors", - sym=self.weights.symmetric, - sharded_infeatures=sharded_in_features, - ) diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py deleted file mode 100644 index 27b8614c..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import List, Union - -import torch - - -from compressed_tensors.quantization import QuantizationArgs, QuantizationType -from text_generation_server.layers.marlin.marlin import GPTQMarlin24Weight -from text_generation_server.utils.weights import Weights, WeightsLoader - - -class WNA16Int24Loader(WeightsLoader): - """ - Loader for W4A16/W8A16 INT 2:4 sparsity compressed-tensors checkpoints. 
- """ - - def __init__(self, weight_args: QuantizationArgs): - super().__init__() - - if weight_args.type != QuantizationType.INT: - raise ValueError( - f"{type(self).__name__} only supports wNa8 int checkpoints" - ) - - if weight_args.strategy == "group" and weight_args.group_size is None: - raise ValueError("`group_size` must be set when `actorder` is `group`") - - self.bits = weight_args.num_bits - self.group_size = weight_args.group_size - - def __str__(self) -> str: - quantization_type = f"W{self.bits}A16 2:4 sparsity" - - return f"{self.__class__.__name__} ({quantization_type})" - - def get_weights(self, weights: Weights, prefix: str): - """ - Get weights at the given prefix and apply without tensor paralllism. - """ - weight_packed = weights.get_tensor(f"{prefix}.weight_packed") - meta = weights.get_tensor(f"{prefix}.meta") - scale_packed = weights.get_tensor(f"{prefix}.scale_packed") - return GPTQMarlin24Weight( - weight_packed=weight_packed, - meta=meta, - scale_packed=scale_packed, - bits=self.bits, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - weight_packed = weights.get_packed_sharded( - f"{prefix}.weight_packed", dim=1, block_sizes=block_sizes - ) - meta = weights.get_packed_sharded( - f"{prefix}.meta", dim=1, block_sizes=block_sizes - ) - scale_packed = weights.get_packed_sharded( - f"{prefix}.scale_packed", dim=1, block_sizes=block_sizes - ) - return GPTQMarlin24Weight( - weight_packed=weight_packed, - meta=meta, - scale_packed=scale_packed, - bits=self.bits, - ) - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - weight_packed = torch.cat( - [weights.get_sharded(f"{p}.weight_packed", dim=1) for p in prefixes], dim=1 - ) - meta = torch.cat( - [weights.get_sharded(f"{p}.meta", dim=1) for p in prefixes], dim=1 - ) - scale_packed = torch.cat( - [weights.get_sharded(f"{p}.scale_packed", dim=1) for p in prefixes], dim=1 - ) - return GPTQMarlin24Weight( - weight_packed=weight_packed, - meta=meta, - scale_packed=scale_packed, - bits=self.bits, - ) - - def get_weights_row(self, weights: Weights, prefix: str): - weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=0) - meta = weights.get_sharded(f"{prefix}.meta", dim=0) - if self.group_size is None: - scale_packed = weights.get_tensor(f"{prefix}.scale_packed") - else: - scale_packed = weights.get_sharded(f"{prefix}.scale_packed", dim=0) - - return GPTQMarlin24Weight( - weight_packed=weight_packed, - meta=meta, - scale_packed=scale_packed, - bits=self.bits, - ) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index e62a334c..90b8f692 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -7,7 +7,7 @@ from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weight, Weights, WeightsLoader -QuantLinear = None +from .hpu import QuantLinear @dataclass @@ -215,14 +215,7 @@ class GPTQWeightsLoader(WeightsLoader): [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 ) - from text_generation_server.layers.gptq import HAS_EXLLAMA - - use_exllama = ( - self.bits == 4 - and HAS_EXLLAMA - and self.quantize == "gptq" - and not self.desc_act - ) + use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act if self.quantize == "gptq" and 
self.quant_method == "gptq": w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] @@ -362,6 +355,3 @@ class GPTQWeightsLoader(WeightsLoader): else False ) self.quant_method = "gptq" - - -HAS_EXLLAMA = False diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/ipex.py b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py similarity index 57% rename from backends/gaudi/server/text_generation_server/layers/gptq/ipex.py rename to backends/gaudi/server/text_generation_server/layers/gptq/hpu.py index 48584e90..25d5c3d2 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/ipex.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py @@ -1,125 +1,181 @@ -import math -import numpy as np -import torch -import torch.nn as nn - -import intel_extension_for_pytorch as ipex - - -class QuantLinear(nn.Module): - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): - super().__init__() - self.register_buffer("qweight", qweight) - self.register_buffer("qzeros", qzeros) - self.register_buffer("scales", scales) - self.register_buffer("g_idx", g_idx) - if bias is not None: - self.register_buffer("bias", bias) - else: - self.bias = None - if bits not in [4]: - raise NotImplementedError("Only 4 bits are supported.") - self.bits = bits - self.maxq = 2**self.bits - 1 - self.groupsize = groupsize - - self.outfeatures = qweight.shape[1] - self.infeatures = qweight.shape[0] * 32 // bits - self.woq_linear = ( - ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight( - self.qweight, - self.scales, - self.qzeros, - self.infeatures, - self.outfeatures, - bias=self.bias, - group_size=self.groupsize, - g_idx=g_idx, - quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM, - dtype=ipex.llm.quantization.QuantDtype.INT4, - ) - ) - - @classmethod - def new(cls, bits, groupsize, infeatures, outfeatures, bias): - if bits not in [4]: - raise NotImplementedError("Only 4 bits are supported.") - - qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) - qzeros = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), - dtype=torch.int32, - ) - scales = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 - ) - g_idx = torch.tensor( - [i // groupsize for i in range(infeatures)], dtype=torch.int32 - ) - if bias: - bias = torch.zeros((outfeatures), dtype=torch.float16) - else: - bias = None - return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) - - def pack(self, linear, scales, zeros, g_idx=None): - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().half() - if linear.bias is not None: - self.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(self.infeatures): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) - / self.scales[self.g_idx[idx]] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - qweight = np.zeros( - (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 - ) - i = 0 - row = 0 - while row < qweight.shape[0]: - if self.bits in [4]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - else: - raise 
NotImplementedError("Only 4 bits are supported.") - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros -= 1 - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros( - (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 - ) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [4]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - else: - raise NotImplementedError("Only 4 bits are supported.") - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures,) - out = self.woq_linear(x.reshape(-1, x.shape[-1])) - return out.reshape(out_shape) +import math +import numpy as np +import torch +import torch.nn as nn + +try: + + convert_from_uint4 = torch.ops.hpu.convert_from_uint4 +except Exception as e: + hpu_import_exception = e + + def error_raiser_hpu(*args, **kwargs): + raise ValueError( + f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}" + ) + + convert_from_uint4 = error_raiser_hpu + + +def pack_tensor(input, bits=4): + normal = input.to(torch.int32) + q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32) + i = 0 + col = 0 + while col < q.shape[1]: + for j in range(i, i + (32 // bits)): + q[:, col] |= normal[:, j] << (bits * (j - i)) + i += 32 // bits + col += 1 + q = q.to(torch.int32) + return q + + +class QuantLinear(nn.Module): + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + self.register_buffer("qweight", qweight) + self.register_buffer("qzeros", qzeros) + self.register_buffer("scales", scales) + self.register_buffer("g_idx", g_idx) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + if bits not in [4]: + raise NotImplementedError("Only 4 bits are supported.") + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize + + self.outfeatures = qweight.shape[1] + self.infeatures = qweight.shape[0] * 32 // bits + self._preprocessing() + + def unpack_zeros_from_cuda_old_format(self): + zeros = torch.bitwise_right_shift( + torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), + self.wf.unsqueeze(0), + ).to(torch.int16 if self.bits == 8 else torch.int8) + + zeros = zeros + 1 + zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to( + self.scales.dtype + ) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important. 
+ zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2]) + return zeros + + def unpack_weight_from_cuda_old_format(self): + weight = torch.bitwise_right_shift( + torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), + self.wf.unsqueeze(-1), + ).to(torch.int16 if self.bits == 8 else torch.int8) + weight = torch.bitwise_and(weight, (2**self.bits) - 1) + weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2])) + return weight + + def _preprocessing(self): + self.qweight = self.qweight.cpu() + weight = self.unpack_weight_from_cuda_old_format() + new_qweight = pack_tensor(weight) + self.qweight = new_qweight.to("hpu") + + # TODO: Support group indexing and remove the check + columns = self.qweight.shape[0] + g_idx_trivial = [i // self.group_size for i in range(columns)] + g_idx_trivial = torch.tensor(g_idx_trivial, dtype=torch.int32) + assert torch.equal( + self.g_idx, g_idx_trivial + ), "Non-trivial tensor g_idx is not supported" + + zeros = self.unpack_zeros_from_cuda_old_format().cpu() + new_qzeros = pack_tensor(zeros) + self.qzeros = new_qzeros.to("hpu") + + @classmethod + def new(cls, bits, groupsize, infeatures, outfeatures, bias): + if bits not in [4]: + raise NotImplementedError("Only 4 bits are supported.") + + qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) + qzeros = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), + dtype=torch.int32, + ) + scales = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 + ) + g_idx = torch.tensor( + [i // groupsize for i in range(infeatures)], dtype=torch.int32 + ) + if bias: + bias = torch.zeros((outfeatures), dtype=torch.float16) + else: + bias = None + return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) + / self.scales[self.g_idx[idx]] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [4]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 + ) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [4]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + x = 
x.reshape(-1, x.shape[-1]) + weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype) + out = torch.matmul(x, weight) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py b/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py deleted file mode 100644 index 3ff3ed58..00000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear -from text_generation_server.layers.marlin.gptq import ( - GPTQMarlinWeightsLoader, - can_use_gptq_marlin, - repack_gptq_for_marlin, -) -from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader - -__all__ = [ - "GPTQMarlinFP8Linear", - "GPTQMarlinWeightsLoader", - "MarlinWeightsLoader", - "can_use_gptq_marlin", - "repack_gptq_for_marlin", -] diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py b/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py deleted file mode 100644 index c2666d2b..00000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py +++ /dev/null @@ -1,141 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn -from text_generation_server.layers.fp8 import fp8_quantize -from text_generation_server.layers.marlin.gptq import _check_valid_shape -from text_generation_server.layers.marlin.util import ( - _check_marlin_kernels, - permute_scales, -) - - -quantization = None - - -MARLIN_TILE_SIZE = 16 - - -class GPTQMarlinFP8Linear(nn.Module): - """ - FP8 GPTQ-Marlin linear layer. - """ - - def __init__( - self, - qweight: torch.Tensor, - scales: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> None: - super().__init__() - - _check_marlin_kernels() - assert quantization is not None - - scales = scales.unsqueeze(0) - if scales.shape[1] == 1: - out_features, in_features = qweight.shape - scales = scales.repeat(1, out_features) - qweight, scales = repack_fp8_for_marlin(qweight, scales) - - in_features = qweight.shape[0] * MARLIN_TILE_SIZE - out_features = scales.shape[1] - _check_valid_shape(in_features=in_features, out_features=out_features) - - self.qweight = qweight - self.scales = scales - self.bias = bias if bias is not None else None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=qweight.device - ) - - @classmethod - def from_unquant(cls, weight, bias, dtype): - qweight, scales = fp8_quantize(weight) - return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) - - @classmethod - def from_fp8( - cls, - weight: torch.Tensor, - scale: torch.Tensor, - bias: torch.Tensor, - dtype: torch.dtype, - **kwargs, - ): - return cls(qweight=weight, scales=scale.to(dtype), bias=bias) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert quantization is not None - - A_flat = A.view(-1, A.shape[-1]) - C = quantization.fp8_marlin_gemm( - A_flat, - self.qweight, - self.scales, - self.workspace, - 8, - A_flat.shape[0], - self.scales.shape[1], - A_flat.shape[1], - ) - C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements). 
- """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - - if fp8_tensor.shape[0] % 4 != 0: - raise ValueError( - f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" - ) - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = torch.zeros( - fp8_tensor.shape[0] // 4, - fp8_tensor.shape[1], - dtype=torch.int32, - device=fp8_tensor.device, - ) - - for i in range(4): - packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) - - return packed - - -def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): - """ - Repack FP8 tensor for GPTQ-Marlin. - """ - - out_features, in_features = weight.shape - - # Torch linear layers weights with shape [out_features, in_features], - # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], - # so transpose before packing. - qweight = pack_fp8_as_int32(weight.t()) - - perm = torch.empty(0, dtype=torch.int, device=qweight.device) - repacked = quantization.gptq_marlin_repack( - qweight, perm, in_features, out_features, 8 - ) - - scales = permute_scales(scales) - - return repacked, scales diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py b/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py deleted file mode 100644 index 185a6d77..00000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py +++ /dev/null @@ -1,465 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Union - -import numpy -import torch -import torch.nn as nn -from loguru import logger -from text_generation_server.layers.marlin.util import ( - _check_marlin_kernels, - marlin_zero_points, - permute_scales, - unpack_cols, -) -from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader - - -quantization = None - - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -GPTQ_MARLIN_BITS = [4, 8] -GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] -MARLIN_TILE_SIZE = 16 - - -def can_use_gptq_marlin( - *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool -) -> bool: - return False - - -class GPTQMarlinWeightsLoader(WeightsLoader): - """ - Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels. 
- """ - - def __init__( - self, - *, - bits: int, - desc_act: bool, - groupsize: int, - quant_method: str, - quantize: str, - sym: bool, - ): - self.bits = bits - self.desc_act = desc_act - self.groupsize = groupsize - self.quant_method = quant_method - self.quantize = quantize - self.sym = sym - - def get_weights(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = weights.get_tensor(f"{prefix}.qweight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - if not self.sym: - qzeros = weights.get_tensor(f"{prefix}.qzeros") - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - scales = weights.get_tensor(f"{prefix}.scales") - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - try: - qweight = weights.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." - ) - scales = weights.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - scales = scales.to(dtype=weights.dtype) - - if not self.sym: - qzeros = weights.get_packed_sharded( - f"{prefix}.qzeros", dim=1, block_sizes=block_sizes - ) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - try: - qweight = torch.cat( - [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" - ) - - scales = torch.cat( - [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - - if not self.sym: - qzeros = torch.cat( - [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_weights_row(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - if not self.sym: - if self.desc_act or self.groupsize == -1: 
- qzeros = weights.get_tensor(f"{prefix}.qzeros") - else: - qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) - - if self.desc_act or self.groupsize == -1: - scales = weights.get_tensor(f"{prefix}.scales") - else: - scales = weights.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = weights.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=sharded_in_features, - ) - - def _get_gptq_params(self, weights: Weights): - if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"): - self.bits = weights.get_tensor("gptq_bits").item() - self.groupsize = weights.get_tensor("gptq_groupsize").item() - self.desc_act = False - # `server quantize` used asymmetric quantization unconditionally - # before the `gptq_sym` setting tensor was added. - self.sym = ( - weights.get_tensor("gptq_sym").item() - if weights.has_tensor("gptq_sym") - else False - ) - self.quant_method = "gptq" - - -@dataclass -class GPTQMarlinWeight(Weight): - """ - Repacked GPTQ Marlin weights. - """ - - qweight: torch.Tensor - qzeros: torch.Tensor - scales: torch.Tensor - g_idx: torch.Tensor - perm: torch.Tensor - bits: int - is_full_k: bool - - def __post_init__(self): - assert self.qweight.dtype == torch.int32 - assert self.scales.dtype in (torch.float16, torch.bfloat16) - assert self.g_idx.dtype == torch.int32 - assert self.perm.dtype == torch.int32 - - def get_linear(self, bias: torch.Tensor): - return GPTQMarlinLinear( - weight=self, - bias=bias, - ) - - -def repack_gptq_for_marlin( - *, - qweight: torch.Tensor, - qzeros: Optional[torch.Tensor], - scales: torch.Tensor, - g_idx: Optional[torch.Tensor], - bits: int, - desc_act: bool, - groupsize: int, - quant_method: str, - sym: bool, - sharded_infeatures: bool, -) -> GPTQMarlinWeight: - """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" - _check_marlin_kernels() - assert quantization is not None - - if bits not in GPTQ_MARLIN_BITS: - supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) - raise RuntimeError( - f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}" - ) - - if groupsize not in GPTQ_MARLIN_GROUP_SIZES: - supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES) - raise RuntimeError( - f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" - ) - if not (sym or quant_method == "awq" or quant_method == "compressed-tensors"): - raise RuntimeError( - "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." 
- ) - - log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.") - - weights_per_int = 32 // bits - in_features = qweight.shape[0] - out_features = qweight.shape[1] - - # AWQ uses column packing, GPTQ uses row packing - if quant_method == "awq": - out_features *= weights_per_int - else: - in_features *= weights_per_int - - if in_features % groupsize != 0: - raise ValueError( - f"Number of input features ({in_features}) not divisible by group size ({groupsize})" - ) - - if g_idx is not None and desc_act and groupsize != -1: - perm = torch.argsort(g_idx).to(torch.int) - g_idx = g_idx[perm] - else: - perm = torch.empty(0, dtype=torch.int, device=qweight.device) - g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) - - if quant_method == "awq": - repacked = quantization.awq_marlin_repack( - qweight, in_features, out_features, bits - ) - if qzeros is not None: - qzeros = awq_to_marlin_zero_points( - qzeros, - in_features // groupsize, - out_features, - bits, - ) - - else: - repacked = quantization.gptq_marlin_repack( - qweight, perm, in_features, out_features, bits - ) - - if qzeros is None: - qzeros = torch.empty(0, dtype=torch.int, device=qweight.device) - - scales = permute_scales(scales) - - is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures) - - return GPTQMarlinWeight( - qweight=repacked, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - perm=perm, - bits=bits, - is_full_k=is_full_k, - ) - - -class GPTQMarlinLinear(nn.Module): - """ - Linear layer for GPTQ weights that were converted for the GPTQ-Marlin - kernels. - """ - - def __init__( - self, - *, - weight: GPTQMarlinWeight, - bias: Optional[torch.Tensor], - ): - super().__init__() - - _check_marlin_kernels() - assert quantization is not None - - in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE - out_features = weight.scales.shape[1] - _check_valid_shape(in_features=in_features, out_features=out_features) - - if weight.bits not in (4, 8): - raise ValueError("GPTQMarlinLinear only supports 4 and 8-bit quantization") - - if weight.qzeros.numel() > 0: - if weight.bits == 4: - self.quant_type = quantization.scalar_types.uint4 - else: - self.quant_type = quantization.scalar_types.uint8 - else: - if weight.bits == 4: - self.quant_type = quantization.scalar_types.uint4b8 - else: - self.quant_type = quantization.scalar_types.uint8b128 - - self.is_full_k = weight.is_full_k - - self.qweight = weight.qweight - self.qzeros = weight.qzeros - self.scales = weight.scales - self.g_idx = weight.g_idx - self.perm = weight.perm - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert quantization is not None - - A_flat = A.view(-1, A.shape[-1]) - C = quantization.gptq_marlin_gemm( - A_flat, - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.perm, - self.workspace, - self.quant_type, - A_flat.shape[0], - self.scales.shape[1], - A_flat.shape[1], - self.is_full_k, - self.qzeros.numel() > 0, - True, - ) - C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. 
- # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def _check_valid_shape(in_features: int, out_features: int): - if (in_features % 128 != 0 or out_features % 64 != 0) and ( - in_features % 64 != 0 or out_features % 128 != 0 - ): - raise ValueError( - f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})." - " The shape elements must be divisible by (128, 64) or (64, 128)." - ) diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py b/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py deleted file mode 100644 index 2ffbcf33..00000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py +++ /dev/null @@ -1,359 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Union - -import torch -import torch.nn as nn - -from text_generation_server.layers.marlin.util import _check_marlin_kernels -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader - -quantization = None - - -class MarlinWeightsLoader(WeightsLoader): - """Loader for Marlin-quantized weights.""" - - def __init__(self, *, bits: int, is_marlin_24: bool): - self.bits = bits - self.is_marlin_24 = is_marlin_24 - - def get_weights(self, weights: "Weights", prefix: str): - """ - Get weights at the given prefix and apply without tensor paralllism. - """ - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = weights.get_tensor(f"{prefix}.B_24") - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = weights.get_tensor(f"{prefix}.B_meta") - s = weights.get_tensor(f"{prefix}.s") - weight = GPTQMarlin24Weight( - weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits - ) - else: - try: - B = weights.get_tensor(f"{prefix}.B") - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." 
- ) - - s = weights.get_tensor(f"{prefix}.s") - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - if self.is_marlin_24: - B = weights.get_packed_sharded( - f"{prefix}.B_24", dim=1, block_sizes=block_sizes - ) - B_meta = weights.get_packed_sharded( - f"{prefix}.B_meta", dim=1, block_sizes=block_sizes - ) - s = weights.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - - weight = GPTQMarlin24Weight( - weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits - ) - else: - B = weights.get_packed_sharded( - f"{prefix}.B", dim=1, block_sizes=block_sizes - ) - s = weights.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - if self.is_marlin_24: - try: - B = torch.cat( - [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized" - ) - - B_meta = torch.cat( - [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 - ) - - s = torch.cat( - [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = GPTQMarlin24Weight( - weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits - ) - else: - try: - B = torch.cat( - [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized" - ) - s = torch.cat( - [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_weights_row(self, weights: Weights, prefix: str): - if self.is_marlin_24: - try: - B = weights.get_sharded(f"{prefix}.B_24", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) - num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = weights.get_tensor(f"{prefix}.s") - else: - s = weights.get_sharded(f"{prefix}.s", dim=0) - - weight = GPTQMarlin24Weight( - weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits - ) - else: - try: - B = weights.get_sharded(f"{prefix}.B", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = weights.get_tensor(f"{prefix}.s") - else: - s = weights.get_sharded(f"{prefix}.s", dim=0) - weight = MarlinWeight(B=B, s=s) - - return weight - - -@dataclass -class MarlinWeight(Weight): - """ - Marlin weights. - - Attributes: - B (torch.Tensor): int4-quantized weights packed into int32. - s (torch.Tensor): bfloat16/float16 scales. 
- """ - - B: torch.Tensor - s: torch.Tensor - - def __post_init__(self): - assert self.B.dtype == torch.int32 - assert self.s.dtype in [torch.float16, torch.bfloat16] - - def get_linear(self, bias: torch.Tensor): - return MarlinLinear(weight=self, bias=bias) - - -class MarlinLinear(nn.Module): - def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): - super().__init__() - - _check_marlin_kernels() - assert quantization is not None - - in_features = weight.B.shape[0] * MARLIN_TILE_SIZE - out_features = weight.s.shape[1] - assert ( - in_features % 128 == 0 - ), f"Number of input features ({in_features}) not divisable by 128" - assert ( - out_features % 256 == 0 - ), f"Number of output features ({out_features}) not divisable by 256" - - groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] - assert groupsize in { - -1, - 128, - }, f"Group size must be -1 or 128, was {groupsize}" - - self.B = weight.B - self.s = weight.s - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.B.device - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert quantization is not None - - C = quantization.marlin_gemm( - A.view(-1, A.shape[-1]), - self.B, - self.s, - self.workspace, - A.shape[0], - self.s.shape[1], - A.shape[1], - ) - C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 -GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_TILE_SIZE = 16 - - -@dataclass -class GPTQMarlin24Weight: - """ - GPTQ-Marlin 2:4 weights. - - Attributes: - B (torch.Tensor): int4-quantized weights packed into int32. - B_meta (torch.Tensor): metadata for 2:4 sparsity. - s (torch.Tensor): float16 scales. - bits: quantized weight size. 
- """ - - weight_packed: torch.Tensor - meta: torch.Tensor - scale_packed: torch.Tensor - bits: int - - def __post_init__(self): - assert self.weight_packed.dtype == torch.int32 - assert self.meta.dtype == torch.int16 - assert self.scale_packed.dtype == torch.float16 - - def get_linear(self, bias: torch.Tensor): - return GPTQMarlin24Linear( - weight=self, - bias=bias, - ) - - -class GPTQMarlin24Linear(nn.Module): - def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): - super().__init__() - - _check_marlin_kernels() - assert quantization is not None - - if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: - supported_bits = ", ".join( - str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS - ) - raise RuntimeError( - f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}" - ) - - in_features = weight.weight_packed.shape[0] * MARLIN_TILE_SIZE * 2 - out_features = weight.scale_packed.shape[1] - groupsize = ( - -1 - if weight.scale_packed.shape[0] == 1 - else in_features // weight.scale_packed.shape[0] - ) - - if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: - supported_sizes = ", ".join( - str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES - ) - raise RuntimeError( - f"Group size {groupsize} is not supported, must be one of: {supported_sizes}" - ) - - if weight.bits == 4: - self.quant_type = quantization.scalar_types.uint4b8 - else: - self.quant_type = quantization.scalar_types.uint8b128 - weights_per_int32 = 32 // weight.bits - - assert ( - out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0 - ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads" - assert ( - out_features % weights_per_int32 == 0 - ), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})" - - assert ( - in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0 - ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads" - if groupsize != -1 and in_features % groupsize != 0: - raise ValueError( - f"Number of input features ({in_features}) not divisable by group size ({groupsize})" - ) - - self.weight_packed = weight.weight_packed - self.meta = weight.meta - self.scale_packed = weight.scale_packed - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL, - dtype=torch.int, - device=weight.weight_packed.device, - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert quantization is not None - - C = quantization.gptq_marlin_24_gemm( - A.view(-1, A.shape[-1]), - self.weight_packed, - self.meta, - self.scale_packed, - self.workspace, - self.quant_type, - A.shape[0], - self.scale_packed.shape[1], - A.shape[1], - ) - - C = C.reshape(A.shape[:-1] + (self.scale_packed.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/util.py b/backends/gaudi/server/text_generation_server/layers/marlin/util.py deleted file mode 100644 index 9f52340f..00000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/util.py +++ /dev/null @@ -1,137 +0,0 @@ -import functools -from typing import List, Tuple - -import numpy -import torch - - -quantization = None - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -def _check_marlin_kernels(): 
- raise NotImplementedError( - "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." - ) - - if quantization is None: - raise NotImplementedError( - "marlin is not installed, install it with: pip install server/marlin" - ) - - -# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54 -@functools.cache -def get_perms() -> Tuple[List[int], List[int]]: - scale_perm = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def permute_scales(scales: torch.Tensor): - scale_perm, scale_perm_single = get_perms() - out_features = scales.shape[1] - if scales.shape[0] == 1: - scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - else: - scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm] - return scales.reshape((-1, out_features)).contiguous() - - -# Functions below are from vLLM - - -def get_pack_factor(bits: int) -> int: - if 32 % bits != 0: - raise ValueError(f"Cannot {bits} bit values into uint32") - return 32 // bits - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - scale_perm, _ = get_perms() - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index d832fb00..c3e5727b 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -398,6 +398,7 @@ class FlashGemmaModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, + adapter_data: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: @@ -479,6 +480,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): block_tables, slots, seqlen, + adapter_data, prefill_cache_indices, hpu_attention_meta, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 80236fe8..a7a85d3a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -47,10 +47,6 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads): prefix, weights, ) - elif config.quantize == "marlin": - raise RuntimeError( - "GPT-2 models with marlin quantization are not yet supported" - ) else: return _load_qkv(config, prefix, weights, head_size, num_heads) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index af0f8f89..2b67501d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -111,6 +111,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, prefill_cache_indices=None, + adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 21c4bc71..cf7c9a79 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -90,7 +90,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: + if config.quantize not in ["gptq", "awq"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 57d4ee64..a41518d7 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -32,10 +32,6 @@ def load_multi_mqa( return _load_multi_mqa_gptq( config, prefix, weights, bias, head_size, num_heads, hidden_size ) - elif config.quantize == "marlin": - raise RuntimeError( - "santacoder models with marlin 
quantization are not yet supported" - ) else: return _load_multi_mqa( config, prefix, weights, bias, head_size, num_heads, hidden_size diff --git a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py index 7a63d4dd..0ee6ed16 100644 --- a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py +++ b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py @@ -588,7 +588,7 @@ class Seq2SeqLM(Model): aliases=aliases, weights_loader=weights_loader, ) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + if config.quantize in ["awq", "gptq"]: weights._set_gptq_params(model_id, revision) model = model_class(config, weights) diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py index ee561acc..a8faf4a5 100644 --- a/backends/gaudi/server/text_generation_server/utils/quantization.py +++ b/backends/gaudi/server/text_generation_server/utils/quantization.py @@ -4,9 +4,7 @@ from dataclasses import dataclass from typing import Optional from huggingface_hub import hf_hub_download -from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin from text_generation_server.utils.weights import ( - DefaultWeightsLoader, WeightsLoader, ) @@ -129,64 +127,13 @@ def get_loader( f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." ) - if can_use_gptq_marlin( + return GPTQWeightsLoader( bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, groupsize=quantizer_config.groupsize, quant_method=quantizer_config.quant_method, quantize=quantize, sym=quantizer_config.sym, - ): - from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader - - return GPTQMarlinWeightsLoader( - bits=quantizer_config.bits, - desc_act=quantizer_config.desc_act, - groupsize=quantizer_config.groupsize, - quant_method=quantizer_config.quant_method, - quantize=quantize, - sym=quantizer_config.sym, - ) - else: - return GPTQWeightsLoader( - bits=quantizer_config.bits, - desc_act=quantizer_config.desc_act, - groupsize=quantizer_config.groupsize, - quant_method=quantizer_config.quant_method, - quantize=quantize, - sym=quantizer_config.sym, - ) - elif quantize == "bitsandbytes": - from text_generation_server.layers.bnb import BNBWeight - - return DefaultWeightsLoader(BNBWeight) - elif quantize == "bitsandbytes-fp4": - from text_generation_server.layers.bnb import BNBFP4Weight - - return DefaultWeightsLoader(BNBFP4Weight) - elif quantize == "bitsandbytes-nf4": - from text_generation_server.layers.bnb import BNBNF4Weight - - return DefaultWeightsLoader(BNBNF4Weight) - elif quantize == "eetq": - from text_generation_server.layers.eetq import EETQWeight - - return DefaultWeightsLoader(EETQWeight) - elif quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2WeightsLoader - - return Exl2WeightsLoader() - elif quantize == "marlin": - from text_generation_server.layers.marlin import MarlinWeightsLoader - - # TODO: improve check once we have one config type per quantize value - if not isinstance(quantizer_config, _QuantizerConfig): - raise ValueError( - f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." 
-            )
-
-        return MarlinWeightsLoader(
-            bits=quantizer_config.bits,
-            is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
         )
     elif quantize == "fp8" or quantize is None:
         from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
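
The int4 path enabled by this patch ends in a plain dequantize-then-matmul, as in the HPU forward added above (weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype); out = torch.matmul(x, weight)). As a rough reference for what such a dequantization produces, the sketch below unpacks GPTQ-style int4 weights (eight 4-bit values per int32, with per-group scales and zero points). It assumes the common row-packed layout, skips act-order (g_idx), AWQ column interleaving, and the zero-point offset used by older GPTQ checkpoints, and the helper name is hypothetical rather than an API from this repository.

import torch


# Hypothetical reference helper (illustration only), not part of text_generation_server.
def dequant_int4_reference(qweight, qzeros, scales, group_size):
    # Assumed shapes (int32 unless noted):
    #   qweight: (in_features // 8, out_features), 8 nibbles packed along dim 0
    #   qzeros:  (n_groups, out_features // 8),    8 nibbles packed along dim 1
    #   scales:  (n_groups, out_features), float16/bfloat16/float32
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=qweight.device)

    # Unpack weight nibbles: (in // 8, 8, out) -> (in_features, out_features)
    w = (qweight.unsqueeze(1) >> shifts.view(1, 8, 1)) & 0xF
    w = w.reshape(-1, qweight.shape[1])

    # Unpack zero-point nibbles: (groups, out // 8, 8) -> (groups, out_features)
    z = (qzeros.unsqueeze(-1) >> shifts.view(1, 1, 8)) & 0xF
    z = z.reshape(qzeros.shape[0], -1)

    # Each block of `group_size` input rows shares one scale/zero-point row.
    group = torch.arange(w.shape[0], device=qweight.device) // group_size
    return (w.float() - z[group].float()) * scales[group].float()

The result has shape (in_features, out_features), so a forward pass is out = x @ weight after casting to x.dtype, matching the shape convention of the matmul in the new HPU code.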