use DefaultWeightsLoader in skip modules

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Author: jiqing-feng
Date: 2025-02-28 10:35:13 +00:00
parent b7bdbbd8c0
commit e66bbfff2e
2 changed files with 8 additions and 14 deletions
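The diff below routes layers listed in modules_to_not_convert through DefaultWeightsLoader instead of having GPTQWeightsLoader fetch the raw ".weight" tensor and wrap it in UnquantizedWeight in each accessor. The following is a minimal, self-contained sketch of that delegation pattern; the class names and the dict-based weight store are simplified stand-ins for illustration, not the text-generation-server API.

    from typing import List


    class DefaultLoader:
        """Loads plain, unquantized weights (illustrative stand-in for a default loader)."""

        def get_weights(self, weights: dict, prefix: str):
            return weights[f"{prefix}.weight"]


    class QuantizedLoader:
        """Loads quantized weights, except for layers listed in modules_to_not_convert."""

        def __init__(self, modules_to_not_convert: List[str]):
            self.modules_to_not_convert = modules_to_not_convert
            self.fallback = DefaultLoader()

        def is_layer_skipped(self, prefix: str) -> bool:
            # A layer is skipped when any excluded module name occurs in its prefix.
            return any(name in prefix for name in self.modules_to_not_convert)

        def get_weights(self, weights: dict, prefix: str):
            if self.is_layer_skipped(prefix):
                # Delegate to the default loader rather than unpacking qweight/qzeros here.
                return self.fallback.get_weights(weights, prefix)
            return weights[f"{prefix}.qweight"], weights[f"{prefix}.qzeros"]


    # "lm_head" stays in full precision, everything else is loaded quantized.
    loader = QuantizedLoader(modules_to_not_convert=["lm_head"])
    tensors = {
        "lm_head.weight": "fp16 tensor",
        "model.layers.0.q_proj.qweight": "packed int32",
        "model.layers.0.q_proj.qzeros": "packed int32",
    }
    print(loader.get_weights(tensors, "lm_head"))                # fp16 tensor
    print(loader.get_weights(tensors, "model.layers.0.q_proj"))  # (qweight, qzeros)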


@@ -10,7 +10,7 @@ from text_generation_server.utils.weights import (
     Weight,
     Weights,
     WeightsLoader,
-    UnquantizedWeight,
+    DefaultWeightsLoader,
 )
 
 if SYSTEM == "ipex":
@@ -95,7 +95,7 @@ class GPTQWeightsLoader(WeightsLoader):
         quant_method: str,
         quantize: str,
         sym: bool,
-        modules_to_not_convert: Optional[List[str]],
+        modules_to_not_convert: List[str],
     ):
         self.bits = bits
         self.desc_act = desc_act
@@ -117,8 +117,7 @@ class GPTQWeightsLoader(WeightsLoader):
         use_exllama = False
         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_tensor(f"{prefix}.weight")
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_weights(weights, prefix)
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
@@ -200,10 +199,9 @@ class GPTQWeightsLoader(WeightsLoader):
         block_sizes: Union[int, List[int]],
     ):
         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_packed_sharded(
-                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            return DefaultWeightsLoader.get_weights_col_packed(
+                weights, prefix, block_sizes
             )
-            return UnquantizedWeight(w)
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -256,10 +254,7 @@ class GPTQWeightsLoader(WeightsLoader):
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
         if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
-            w = torch.cat(
-                [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes], dim=dim
-            )
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -339,8 +334,7 @@ class GPTQWeightsLoader(WeightsLoader):
         use_exllama = False
         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_sharded(f"{prefix}.weight", dim=1)
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_weights_row(weights, prefix)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
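A detail worth noting in the hunks above: the shard dimension flips between ".weight" and ".qweight". A plain Linear weight is stored as (out_features, in_features), while a GPTQ-packed qweight is laid out as (in_features // pack_factor, out_features), so column (output) sharding slices dim 0 of .weight but dim 1 of .qweight, and row (input) sharding slices dim 1 of .weight but dim 0 of .qweight. The sketch below spells out those shapes with made-up sizes; it is illustrative, not taken from the repository.

    # Assumed example sizes: 4-bit GPTQ packs 8 values per int32, 2 tensor-parallel ranks.
    out_features, in_features, pack_factor, world_size = 8, 16, 8, 2

    weight_shape = (out_features, in_features)                  # plain fp16/bf16 weight
    qweight_shape = (in_features // pack_factor, out_features)  # GPTQ packed int32 weight

    # Column (output) sharding, one rank's slice:
    col_shard_weight = (out_features // world_size, in_features)                   # dim=0 of .weight
    col_shard_qweight = (in_features // pack_factor, out_features // world_size)   # dim=1 of .qweight

    # Row (input) sharding, one rank's slice:
    row_shard_weight = (out_features, in_features // world_size)                   # dim=1 of .weight
    row_shard_qweight = (in_features // pack_factor // world_size, out_features)   # dim=0 of .qweight

    print(col_shard_weight, col_shard_qweight)
    print(row_shard_weight, row_shard_qweight)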


@@ -77,7 +77,7 @@ def _get_quantizer_config(model_id, revision):
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
         desc_act = data["quantization_config"].get("desc_act", False)
         modules_to_not_convert = data["quantization_config"].get(
-            "modules_to_not_convert", None
+            "modules_to_not_convert", []
        )
     except Exception:
         filename = "quantize_config.json"