diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index 25387682..f8a62cf5 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -6,7 +6,13 @@ import torch
 from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+from text_generation_server.utils.weights import (
+    Weight,
+    Weights,
+    WeightsLoader,
+    DefaultWeightsLoader,
+    UnquantizedWeight,
+)
 
 if SYSTEM == "ipex":
     from .ipex import QuantLinear
@@ -90,6 +96,7 @@ class GPTQWeightsLoader(WeightsLoader):
         quant_method: str,
         quantize: str,
         sym: bool,
+        modules_to_not_convert: List[str],
     ):
         self.bits = bits
         self.desc_act = desc_act
@@ -97,6 +104,7 @@ class GPTQWeightsLoader(WeightsLoader):
         self.quant_method = quant_method
         self.quantize = quantize
         self.sym = sym
+        self.modules_to_not_convert = modules_to_not_convert
 
     def get_weights(self, weights: Weights, prefix: str):
         self._get_gptq_params(weights)
@@ -109,6 +117,9 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader(UnquantizedWeight).get_weights(weights, prefix)
+
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
         except RuntimeError:
@@ -175,12 +186,23 @@ class GPTQWeightsLoader(WeightsLoader):
             use_exllama=use_exllama,
         )
 
+    def is_layer_skipped_quantization(
+        self, prefix: str, modules_to_not_convert: List[str]
+    ):
+        if modules_to_not_convert is None:
+            return False
+        return any(module_name in prefix for module_name in modules_to_not_convert)
+
     def get_weights_col_packed(
         self,
         weights: Weights,
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader(UnquantizedWeight).get_weights_col_packed(
+                weights, prefix, block_sizes
+            )
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -232,6 +254,8 @@ class GPTQWeightsLoader(WeightsLoader):
         )
 
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            return DefaultWeightsLoader(UnquantizedWeight).get_multi_weights_col(weights, prefixes, dim)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -310,6 +334,8 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader(UnquantizedWeight).get_weights_row(weights, prefix)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py
index e44cf64f..7324b33f 100644
--- a/server/text_generation_server/utils/quantization.py
+++ b/server/text_generation_server/utils/quantization.py
@@ -21,6 +21,7 @@ class _QuantizerConfig:
     quant_method: str
     sym: bool
     weight_block_size: Optional[List[int]]
+    modules_to_not_convert: Optional[List[str]]
 
 
 @dataclass
@@ -51,6 +52,7 @@ def _get_quantizer_config(model_id, revision):
     sym = False
     desc_act = False
     weight_block_size = None
+    modules_to_not_convert = None
 
filename = "config.json" try: @@ -73,7 +75,10 @@ def _get_quantizer_config(model_id, revision): # Order is important here, desc_act is missing on some real models quant_method = data["quantization_config"]["quant_method"] checkpoint_format = data["quantization_config"].get("checkpoint_format") - desc_act = data["quantization_config"]["desc_act"] + desc_act = data["quantization_config"].get("desc_act", False) + modules_to_not_convert = data["quantization_config"].get( + "modules_to_not_convert", [] + ) except Exception: filename = "quantize_config.json" try: @@ -110,6 +115,7 @@ def _get_quantizer_config(model_id, revision): sym=sym, desc_act=desc_act, weight_block_size=weight_block_size, + modules_to_not_convert=modules_to_not_convert, ) @@ -159,6 +165,7 @@ def get_loader( quant_method=quantizer_config.quant_method, quantize=quantize, sym=quantizer_config.sym, + modules_to_not_convert=quantizer_config.modules_to_not_convert, ) elif quantize == "bitsandbytes": from text_generation_server.layers.bnb import BNBWeight