diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index 3526530d..d443f94a 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -6,7 +6,12 @@ import torch
 from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weight, Weights, WeightsLoader, UnquantizedWeight
+from text_generation_server.utils.weights import (
+    Weight,
+    Weights,
+    WeightsLoader,
+    UnquantizedWeight,
+)
 
 if SYSTEM == "ipex":
     from .ipex import QuantLinear
@@ -181,7 +186,9 @@ class GPTQWeightsLoader(WeightsLoader):
             use_exllama=use_exllama,
         )
 
-    def is_layer_skipped_quantization(self, prefix: str, modules_to_not_convert: List[str]):
+    def is_layer_skipped_quantization(
+        self, prefix: str, modules_to_not_convert: List[str]
+    ):
         if modules_to_not_convert is None:
             return False
         return any(module_name in prefix for module_name in modules_to_not_convert)
diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py
index 284726a3..8a62deec 100644
--- a/server/text_generation_server/utils/quantization.py
+++ b/server/text_generation_server/utils/quantization.py
@@ -76,7 +76,9 @@ def _get_quantizer_config(model_id, revision):
         quant_method = data["quantization_config"]["quant_method"]
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
         desc_act = data["quantization_config"].get("desc_act", False)
-        modules_to_not_convert = data["quantization_config"].get("modules_to_not_convert", None)
+        modules_to_not_convert = data["quantization_config"].get(
+            "modules_to_not_convert", None
+        )
     except Exception:
         filename = "quantize_config.json"
         try:
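
Note: the is_layer_skipped_quantization helper added above does plain substring matching of configured module names against the weight prefix. A minimal standalone sketch of that behavior follows; the logic is copied from the diff, while the config value ["lm_head", "visual"] and the standalone function form are illustrative assumptions, not taken from the patch.

from typing import List, Optional


def is_layer_skipped_quantization(
    prefix: str, modules_to_not_convert: Optional[List[str]]
) -> bool:
    # Same logic as the method in the diff: quantization is skipped when any
    # configured module name occurs anywhere in the layer's weight prefix.
    if modules_to_not_convert is None:
        return False
    return any(module_name in prefix for module_name in modules_to_not_convert)


# Hypothetical config values, for illustration only.
modules_to_not_convert = ["lm_head", "visual"]

assert is_layer_skipped_quantization("lm_head", modules_to_not_convert)
assert is_layer_skipped_quantization("model.visual.blocks.0.mlp", modules_to_not_convert)
assert not is_layer_skipped_quantization("model.layers.0.self_attn.q_proj", modules_to_not_convert)

Because the check is a substring test rather than an exact match, a name like "visual" skips every layer whose prefix contains it, which matches how modules_to_not_convert is consumed from the quantization_config read in _get_quantizer_config.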