Add modules_to_not_convert in quantized model (#3053)

* fix modules_to_not_convert

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tp quant skip

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert unquantized changes

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* use DefaultWeightsLoader in skip modules

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng 2025-03-10 22:03:51 +08:00 committed by GitHub
parent bbe218a4f7
commit cae0cbe87d
2 changed files with 34 additions and 2 deletions
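
For context: this change makes the server honor the modules_to_not_convert list that a quantized checkpoint can carry in its config.json, loading the listed layers unquantized. A minimal sketch of such a fragment (the surrounding values are hypothetical; only the modules_to_not_convert key is what this commit consumes):

# Hypothetical quantization_config fragment from a quantized model's config.json.
# Layers whose weight prefix matches an entry in modules_to_not_convert are
# loaded in full precision instead of as GPTQ tensors.
quantization_config = {
    "quant_method": "gptq",
    "bits": 4,
    "sym": True,
    "modules_to_not_convert": ["lm_head"],  # keep these layers unquantized
}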

server/text_generation_server/layers/gptq/__init__.py

@@ -6,7 +6,12 @@ import torch
 from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+from text_generation_server.utils.weights import (
+    Weight,
+    Weights,
+    WeightsLoader,
+    DefaultWeightsLoader,
+)
 
 if SYSTEM == "ipex":
     from .ipex import QuantLinear
@@ -90,6 +95,7 @@ class GPTQWeightsLoader(WeightsLoader):
         quant_method: str,
         quantize: str,
         sym: bool,
+        modules_to_not_convert: List[str],
     ):
         self.bits = bits
         self.desc_act = desc_act
@@ -97,6 +103,7 @@ class GPTQWeightsLoader(WeightsLoader):
         self.quant_method = quant_method
         self.quantize = quantize
         self.sym = sym
+        self.modules_to_not_convert = modules_to_not_convert
 
     def get_weights(self, weights: Weights, prefix: str):
         self._get_gptq_params(weights)
@@ -109,6 +116,9 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights(weights, prefix)
+
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
         except RuntimeError:
@@ -175,12 +185,23 @@ class GPTQWeightsLoader(WeightsLoader):
             use_exllama=use_exllama,
         )
 
+    def is_layer_skipped_quantization(
+        self, prefix: str, modules_to_not_convert: List[str]
+    ):
+        if modules_to_not_convert is None:
+            return False
+        return any(module_name in prefix for module_name in modules_to_not_convert)
+
     def get_weights_col_packed(
         self,
         weights: Weights,
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights_col_packed(
+                weights, prefix, block_sizes
+            )
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -232,6 +253,8 @@ class GPTQWeightsLoader(WeightsLoader):
         )
 
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -310,6 +333,8 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights_row(weights, prefix)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
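
The new is_layer_skipped_quantization helper (added in the @@ -175,12 hunk above) is a plain substring test against the weight prefix, so an entry such as "lm_head" also matches any longer prefix that contains it. A self-contained sketch of that behavior, using hypothetical layer prefixes:

from typing import List, Optional


def is_layer_skipped_quantization(
    prefix: str, modules_to_not_convert: Optional[List[str]]
) -> bool:
    # Mirrors the method added to GPTQWeightsLoader: None disables skipping;
    # otherwise skip when any configured name occurs in the prefix.
    if modules_to_not_convert is None:
        return False
    return any(module_name in prefix for module_name in modules_to_not_convert)


# Hypothetical prefixes: lm_head is skipped, a quantized projection is not.
assert is_layer_skipped_quantization("lm_head", ["lm_head"])
assert not is_layer_skipped_quantization("model.layers.0.self_attn.q_proj", ["lm_head"])
assert not is_layer_skipped_quantization("lm_head", None)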

server/text_generation_server/utils/quantization.py

@@ -21,6 +21,7 @@ class _QuantizerConfig:
     quant_method: str
     sym: bool
     weight_block_size: Optional[List[int]]
+    modules_to_not_convert: Optional[List[str]]
 
 
 @dataclass
@@ -51,6 +52,7 @@ def _get_quantizer_config(model_id, revision):
     sym = False
     desc_act = False
     weight_block_size = None
+    modules_to_not_convert = None
 
     filename = "config.json"
     try:
@@ -73,7 +75,10 @@ def _get_quantizer_config(model_id, revision):
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        desc_act = data["quantization_config"]["desc_act"]
+        desc_act = data["quantization_config"].get("desc_act", False)
+        modules_to_not_convert = data["quantization_config"].get(
+            "modules_to_not_convert", []
+        )
     except Exception:
         filename = "quantize_config.json"
         try:
@@ -110,6 +115,7 @@ def _get_quantizer_config(model_id, revision):
         sym=sym,
         desc_act=desc_act,
         weight_block_size=weight_block_size,
+        modules_to_not_convert=modules_to_not_convert,
     )
 
 
@@ -159,6 +165,7 @@ def get_loader(
             quant_method=quantizer_config.quant_method,
             quantize=quantize,
             sym=quantizer_config.sym,
+            modules_to_not_convert=quantizer_config.modules_to_not_convert,
         )
     elif quantize == "bitsandbytes":
         from text_generation_server.layers.bnb import BNBWeight
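
On the parsing side, the key is read with .get("modules_to_not_convert", []), so configs that predate it keep working. An illustrative sketch (the data dict below is hypothetical):

# Mirrors the lookup added to _get_quantizer_config above.
data = {"quantization_config": {"quant_method": "gptq", "bits": 4, "sym": True}}

modules_to_not_convert = data["quantization_config"].get("modules_to_not_convert", [])
assert modules_to_not_convert == []  # key absent: no layers are excluded

data["quantization_config"]["modules_to_not_convert"] = ["lm_head"]
modules_to_not_convert = data["quantization_config"].get("modules_to_not_convert", [])
assert modules_to_not_convert == ["lm_head"]  # key present: forwarded to GPTQWeightsLoader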