diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index 7e8380352..3526530d1 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -6,7 +6,7 @@ import torch
 from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+from text_generation_server.utils.weights import Weight, Weights, WeightsLoader, UnquantizedWeight
 
 if SYSTEM == "ipex":
     from .ipex import QuantLinear
@@ -90,6 +90,7 @@ class GPTQWeightsLoader(WeightsLoader):
         quant_method: str,
         quantize: str,
         sym: bool,
+        modules_to_not_convert: Optional[List[str]],
     ):
         self.bits = bits
         self.desc_act = desc_act
@@ -97,6 +98,7 @@ class GPTQWeightsLoader(WeightsLoader):
         self.quant_method = quant_method
         self.quantize = quantize
         self.sym = sym
+        self.modules_to_not_convert = modules_to_not_convert
 
     def get_weights(self, weights: Weights, prefix: str):
         self._get_gptq_params(weights)
@@ -109,6 +111,10 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            w = weights.get_tensor(f"{prefix}.weight")
+            return UnquantizedWeight(w)
+
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
         except RuntimeError:
@@ -171,9 +177,15 @@ class GPTQWeightsLoader(WeightsLoader):
             g_idx=g_idx,
             bits=self.bits,
             groupsize=self.groupsize,
+            use_awq_kernel=self.quantize == "awq",
             use_exllama=use_exllama,
         )
 
+    def is_layer_skipped_quantization(self, prefix: str, modules_to_not_convert: List[str]):
+        if modules_to_not_convert is None:
+            return False
+        return any(module_name in prefix for module_name in modules_to_not_convert)
+
     def get_weights_col_packed(
         self,
         weights: Weights,
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index 772142861..007f99d05 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -85,6 +85,8 @@ class UnquantizedSparseMoELayer(nn.Module):
                 use_grouped_topk=self.n_expert_group is not None,
                 num_expert_group=self.n_expert_group,
                 topk_group=self.topk_group,
+                scoring_func=self.scoring_func,
+                e_score_correction_bias=self.e_score_correction_bias,
             )
         return fused_moe(
             x,
diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py
index e44cf64fe..284726a37 100644
--- a/server/text_generation_server/utils/quantization.py
+++ b/server/text_generation_server/utils/quantization.py
@@ -21,6 +21,7 @@ class _QuantizerConfig:
     quant_method: str
     sym: bool
     weight_block_size: Optional[List[int]]
+    modules_to_not_convert: Optional[List[str]]
 
 
 @dataclass
@@ -51,6 +52,7 @@ def _get_quantizer_config(model_id, revision):
     sym = False
     desc_act = False
     weight_block_size = None
+    modules_to_not_convert = None
 
     filename = "config.json"
     try:
@@ -73,7 +75,8 @@ def _get_quantizer_config(model_id, revision):
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        desc_act = data["quantization_config"]["desc_act"]
+        desc_act = data["quantization_config"].get("desc_act", False)
+        modules_to_not_convert = data["quantization_config"].get("modules_to_not_convert", None)
     except Exception:
         filename = "quantize_config.json"
         try:
@@ -110,6 +113,7 @@ def _get_quantizer_config(model_id, revision):
         sym=sym,
         desc_act=desc_act,
         weight_block_size=weight_block_size,
+        modules_to_not_convert=modules_to_not_convert,
     )
 
 
@@ -159,6 +163,7 @@ def get_loader(
             quant_method=quantizer_config.quant_method,
             quantize=quantize,
             sym=quantizer_config.sym,
+            modules_to_not_convert=quantizer_config.modules_to_not_convert,
         )
     elif quantize == "bitsandbytes":
         from text_generation_server.layers.bnb import BNBWeight
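
Note (illustrative, not part of the patch): a minimal sketch of how modules_to_not_convert flows from a checkpoint's quantization_config into the prefix match added above. The quantization_config dict and layer prefixes below are made-up examples; only the substring-matching logic mirrors the is_layer_skipped_quantization hunk.

    # Hypothetical standalone version of the helper; config values are invented.
    from typing import List, Optional

    def is_layer_skipped_quantization(
        prefix: str, modules_to_not_convert: Optional[List[str]]
    ) -> bool:
        # Same substring match as the method added to GPTQWeightsLoader.
        if modules_to_not_convert is None:
            return False
        return any(module_name in prefix for module_name in modules_to_not_convert)

    # Example quantization_config as _get_quantizer_config would read it
    # from config.json (values are illustrative, not from a real model).
    quantization_config = {
        "quant_method": "awq",
        "bits": 4,
        "modules_to_not_convert": ["visual", "lm_head"],
    }
    skip = quantization_config["modules_to_not_convert"]

    # Quantized layer: no match, so get_weights falls through to the qweight path.
    assert not is_layer_skipped_quantization("model.layers.0.self_attn.q_proj", skip)
    # Skipped layer: matches "visual", so it is loaded as an UnquantizedWeight.
    assert is_layer_skipped_quantization("visual.blocks.3.mlp.fc1", skip)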