From e66bbfff2e50272e9fd0857a142048c175a3aa1f Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 28 Feb 2025 10:35:13 +0000
Subject: [PATCH] use DefaultWeightsLoader in skip modules

Signed-off-by: jiqing-feng
---
 .../layers/gptq/__init__.py | 20 +++++++-------------
 .../utils/quantization.py   |  2 +-
 2 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index b2371967..f8a62cf5 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -10,7 +10,7 @@ from text_generation_server.utils.weights import (
     Weight,
     Weights,
     WeightsLoader,
-    UnquantizedWeight,
+    DefaultWeightsLoader,
 )
 if SYSTEM == "ipex":
     from .ipex import QuantLinear
@@ -95,7 +95,7 @@ class GPTQWeightsLoader(WeightsLoader):
         quant_method: str,
         quantize: str,
         sym: bool,
-        modules_to_not_convert: Optional[List[str]],
+        modules_to_not_convert: List[str],
     ):
         self.bits = bits
         self.desc_act = desc_act
@@ -117,8 +117,7 @@ class GPTQWeightsLoader(WeightsLoader):
             use_exllama = False
 
         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_tensor(f"{prefix}.weight")
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_weights(weights, prefix)
 
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
@@ -200,10 +199,9 @@ class GPTQWeightsLoader(WeightsLoader):
         block_sizes: Union[int, List[int]],
     ):
         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_packed_sharded(
-                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            return DefaultWeightsLoader.get_weights_col_packed(
+                weights, prefix, block_sizes
             )
-            return UnquantizedWeight(w)
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -256,10 +254,7 @@ class GPTQWeightsLoader(WeightsLoader):
 
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
         if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
-            w = torch.cat(
-                [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes], dim=dim
-            )
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -339,8 +334,7 @@ class GPTQWeightsLoader(WeightsLoader):
             use_exllama = False
 
         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_sharded(f"{prefix}.weight", dim=1)
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_weights_row(weights, prefix)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py
index 8a62deec..7324b33f 100644
--- a/server/text_generation_server/utils/quantization.py
+++ b/server/text_generation_server/utils/quantization.py
@@ -77,7 +77,7 @@ def _get_quantizer_config(model_id, revision):
             checkpoint_format = data["quantization_config"].get("checkpoint_format")
             desc_act = data["quantization_config"].get("desc_act", False)
             modules_to_not_convert = data["quantization_config"].get(
-                "modules_to_not_convert", None
+                "modules_to_not_convert", []
             )
         except Exception:
             filename = "quantize_config.json"
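
For context, below is a minimal, self-contained sketch of the skip-module flow this patch touches: layers whose prefix matches an entry in modules_to_not_convert bypass the GPTQ tensors (qweight/qzeros/scales) and are loaded from the plain `.weight` tensor instead, which is what the delegation to DefaultWeightsLoader accomplishes in the diff above. The names PlainWeight, SimpleWeights and GPTQLikeLoader are illustrative stand-ins, not classes from text-generation-server, and the substring-based skip rule is only an assumption about how is_layer_skipped_quantization behaves.

from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class PlainWeight:
    # Illustrative stand-in for an unquantized weight wrapper (cf. UnquantizedWeight).
    weight: torch.Tensor


class SimpleWeights:
    # Illustrative stand-in for the Weights accessor: maps tensor names to tensors.
    def __init__(self, tensors: Dict[str, torch.Tensor]):
        self.tensors = tensors

    def get_tensor(self, name: str) -> torch.Tensor:
        return self.tensors[name]


def is_layer_skipped_quantization(prefix: str, modules_to_not_convert: List[str]) -> bool:
    # Assumed rule: a layer is skipped when any configured module name occurs in its prefix.
    return any(module in prefix for module in modules_to_not_convert)


class GPTQLikeLoader:
    # Illustrative loader: quantized path for most layers, plain path for skipped ones.
    def __init__(self, modules_to_not_convert: List[str]):
        # An empty list (the new default in _get_quantizer_config) means "quantize everything",
        # so the membership check below needs no None guard.
        self.modules_to_not_convert = modules_to_not_convert

    def get_weights(self, weights: SimpleWeights, prefix: str):
        if is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
            # Skipped modules fall back to the plain `.weight` tensor, mirroring the
            # DefaultWeightsLoader delegation introduced by the patch.
            return PlainWeight(weights.get_tensor(f"{prefix}.weight"))
        # A real GPTQ loader would assemble qweight/qzeros/scales here instead.
        raise NotImplementedError("GPTQ path elided in this sketch")


loader = GPTQLikeLoader(modules_to_not_convert=["lm_head"])
weights = SimpleWeights({"lm_head.weight": torch.zeros(4, 4)})
print(loader.get_weights(weights, "lm_head"))

Switching the config default from None to [] (second file of the patch) fits this pattern: an empty list is a valid "nothing to skip" value, so the loader's type hint can drop Optional and every call site can iterate the list directly.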