use DefaultWeightsLoader in skip modules

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng 2025-02-28 10:35:13 +00:00
parent b7bdbbd8c0
commit e66bbfff2e
2 changed files with 8 additions and 14 deletions

Changed file 1 of 2:

@@ -10,7 +10,7 @@ from text_generation_server.utils.weights import (
     Weight,
     Weights,
     WeightsLoader,
-    UnquantizedWeight,
+    DefaultWeightsLoader,
 )
 
 if SYSTEM == "ipex":
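
The import hunk swaps UnquantizedWeight for DefaultWeightsLoader. For context, a minimal paraphrased sketch of the delegation this enables, assuming the loader wraps plain `{prefix}.weight` tensors in a weight class the same way the removed inline code did (names and shape here are illustrative, not quoted from the repository):

    # Paraphrased sketch, not the repository's verbatim class: a default
    # loader that reads unquantized `.weight` tensors and wraps them.
    class DefaultWeightsLoaderSketch:
        def __init__(self, weight_class):
            # e.g. UnquantizedWeight
            self.weight_class = weight_class

        def get_weights(self, weights, prefix: str):
            # Same effect as the removed inline code:
            #   w = weights.get_tensor(f"{prefix}.weight")
            #   return UnquantizedWeight(w)
            return self.weight_class(weights.get_tensor(f"{prefix}.weight"))

        def get_weights_row(self, weights, prefix: str):
            return self.weight_class(weights.get_sharded(f"{prefix}.weight", dim=1))

Centralizing the unquantized path in one loader keeps the GPTQ loader's four entry points from each re-implementing the same tensor lookup, which is the point of the hunks that follow.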
@@ -95,7 +95,7 @@ class GPTQWeightsLoader(WeightsLoader):
         quant_method: str,
         quantize: str,
         sym: bool,
-        modules_to_not_convert: Optional[List[str]],
+        modules_to_not_convert: List[str],
     ):
         self.bits = bits
         self.desc_act = desc_act
@@ -117,8 +117,7 @@ class GPTQWeightsLoader(WeightsLoader):
         use_exllama = False
 
         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_tensor(f"{prefix}.weight")
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_weights(weights, prefix)
 
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
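
Each of the four entry points guards on is_layer_skipped_quantization, which the diff references but does not show. A plausible reconstruction, hypothetical and possibly differing from the repository's version:

    from typing import List, Optional


    def is_layer_skipped_quantization(
        prefix: str, modules_to_not_convert: Optional[List[str]]
    ) -> bool:
        # True when any skip-listed module name occurs in the weight prefix,
        # e.g. "lm_head" matches both "lm_head" and "model.lm_head".
        if not modules_to_not_convert:
            return False
        return any(module in prefix for module in modules_to_not_convert)


    assert is_layer_skipped_quantization("model.lm_head", ["lm_head"])
    assert not is_layer_skipped_quantization("model.layers.0.q_proj", ["lm_head"])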
@@ -200,10 +199,9 @@ class GPTQWeightsLoader(WeightsLoader):
         block_sizes: Union[int, List[int]],
     ):
         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_packed_sharded(
-                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            return DefaultWeightsLoader.get_weights_col_packed(
+                weights, prefix, block_sizes
             )
-            return UnquantizedWeight(w)
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -256,10 +254,7 @@ class GPTQWeightsLoader(WeightsLoader):
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
         if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
-            w = torch.cat(
-                [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes], dim=dim
-            )
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
             )
@@ -339,8 +334,7 @@ class GPTQWeightsLoader(WeightsLoader):
         use_exllama = False
 
         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_sharded(f"{prefix}.weight", dim=1)
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_weights_row(weights, prefix)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
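
Taken together, the four hunks route every skip-listed layer through the same unquantized default path. A hedged usage sketch, assuming `loader` is a GPTQWeightsLoader configured with modules_to_not_convert=["lm_head"] and `weights` is an open Weights handle (both names illustrative):

    # "lm_head" sits in the skip list, so these take the early-return
    # unquantized path instead of looking up qweight/qzeros/scales.
    w_row = loader.get_weights_row(weights, "lm_head")
    w_col = loader.get_multi_weights_col(weights, ["lm_head"], dim=0)

    # Any other prefix still goes through the GPTQ branch below the guard.
    q = loader.get_weights_row(weights, "model.layers.0.self_attn.q_proj")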

Changed file 2 of 2:

@@ -77,7 +77,7 @@ def _get_quantizer_config(model_id, revision):
             checkpoint_format = data["quantization_config"].get("checkpoint_format")
             desc_act = data["quantization_config"].get("desc_act", False)
             modules_to_not_convert = data["quantization_config"].get(
-                "modules_to_not_convert", None
+                "modules_to_not_convert", []
             )
     except Exception:
         filename = "quantize_config.json"
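
For reference, a hedged example of the quantization_config block this helper reads from a model's config.json; only keys visible in the hunk are shown, and the values are invented for the example:

    import json

    # Illustrative config.json fragment; keys mirror the .get(...) calls in
    # _get_quantizer_config, values are made up.
    data = json.loads("""
    {
      "quantization_config": {
        "checkpoint_format": "gptq",
        "desc_act": false,
        "modules_to_not_convert": ["lm_head"]
      }
    }
    """)

    # With the new default, a config that omits the key now yields [] instead
    # of None, matching the loader's non-Optional List[str] signature above.
    modules = data["quantization_config"].get("modules_to_not_convert", [])
    print(modules)  # ['lm_head']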