use DefaultWeightsLoader in skip modules
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
parent b7bdbbd8c0
commit e66bbfff2e
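The change routes modules listed in modules_to_not_convert (layers that should stay unquantized, for example an lm_head kept in full precision) through DefaultWeightsLoader instead of constructing UnquantizedWeight by hand in each GPTQWeightsLoader method. As background, here is a minimal sketch of what a skip check like is_layer_skipped_quantization could look like, assuming it simply matches the layer prefix against the configured module names; the repository's actual implementation may differ:

from typing import List


def is_layer_skipped_quantization(prefix: str, modules_to_not_convert: List[str]) -> bool:
    # Assumption for illustration: a layer is "skipped" when any configured
    # module name occurs in its prefix, e.g. "lm_head" matches "lm_head".
    return any(module in prefix for module in modules_to_not_convert)


# Example: the lm_head stays unquantized, an attention projection does not.
print(is_layer_skipped_quantization("lm_head", ["lm_head"]))                           # True
print(is_layer_skipped_quantization("model.layers.0.self_attn.q_proj", ["lm_head"]))   # False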
@@ -10,7 +10,7 @@ from text_generation_server.utils.weights import (
     Weight,
     Weights,
     WeightsLoader,
-    UnquantizedWeight,
+    DefaultWeightsLoader,
 )

 if SYSTEM == "ipex":
@@ -95,7 +95,7 @@ class GPTQWeightsLoader(WeightsLoader):
         quant_method: str,
         quantize: str,
         sym: bool,
-        modules_to_not_convert: Optional[List[str]],
+        modules_to_not_convert: List[str],
     ):
         self.bits = bits
         self.desc_act = desc_act
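Tightening modules_to_not_convert from Optional[List[str]] to List[str] pairs with the config-parsing change at the end of this diff: the value is now always a list, possibly empty, so the skip checks in the loader never need a None guard. A small illustration of the simplification (variable names exist only for this example):

# With an Optional value, every membership test needs a guard:
maybe_skip = None
skipped = maybe_skip is not None and "lm_head" in maybe_skip

# With a plain list defaulting to [], the check is direct:
always_list = []
skipped = "lm_head" in always_list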
@@ -117,8 +117,7 @@ class GPTQWeightsLoader(WeightsLoader):
             use_exllama = False

         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_tensor(f"{prefix}.weight")
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_weights(weights, prefix)

         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
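The hunk above delegates the skipped path to DefaultWeightsLoader.get_weights instead of wrapping the tensor in UnquantizedWeight locally. A rough, self-contained sketch of that delegation idea, assuming the default loader simply hands back the plain "<prefix>.weight" tensor; the real DefaultWeightsLoader in text_generation_server.utils.weights may be structured differently:

import torch
from typing import Dict


class FakeWeights:
    # Tiny stand-in for the Weights object, for illustration only.
    def __init__(self, tensors: Dict[str, torch.Tensor]):
        self._tensors = tensors

    def get_tensor(self, name: str) -> torch.Tensor:
        return self._tensors[name]


class SketchDefaultLoader:
    # Assumption: the default (unquantized) loader reads "<prefix>.weight" as-is.
    @staticmethod
    def get_weights(weights: FakeWeights, prefix: str) -> torch.Tensor:
        return weights.get_tensor(f"{prefix}.weight")


weights = FakeWeights({"lm_head.weight": torch.zeros(4, 4)})
print(SketchDefaultLoader.get_weights(weights, "lm_head").shape)  # torch.Size([4, 4])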
@@ -200,10 +199,9 @@ class GPTQWeightsLoader(WeightsLoader):
         block_sizes: Union[int, List[int]],
     ):
         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_packed_sharded(
-                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            return DefaultWeightsLoader.get_weights_col_packed(
+                weights, prefix, block_sizes
             )
-            return UnquantizedWeight(w)
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -256,10 +254,7 @@ class GPTQWeightsLoader(WeightsLoader):

     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
         if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
-            w = torch.cat(
-                [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes], dim=dim
-            )
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -339,8 +334,7 @@ class GPTQWeightsLoader(WeightsLoader):
             use_exllama = False

         if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
-            w = weights.get_sharded(f"{prefix}.weight", dim=1)
-            return UnquantizedWeight(w)
+            return DefaultWeightsLoader.get_weights_row(weights, prefix)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
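One detail from the row-loading hunk above (which now delegates to DefaultWeightsLoader.get_weights_row): the plain weight on the skipped path was sharded along dim=1, while the GPTQ qweight is sharded along dim=0, reflecting the packed layout of GPTQ tensors. A minimal illustration of sharding a plain weight across tensor-parallel ranks, roughly what Weights.get_sharded does (illustrative only, not the repository's code):

import torch


def shard(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Split the tensor evenly along `dim` and return this rank's slice.
    size = tensor.shape[dim] // world_size
    return tensor.narrow(dim, rank * size, size)


w = torch.arange(32, dtype=torch.float32).reshape(4, 8)
print(shard(w, dim=1, rank=0, world_size=2).shape)  # torch.Size([4, 4]), column shard
print(shard(w, dim=0, rank=1, world_size=2).shape)  # torch.Size([2, 8]), row shard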
@@ -77,7 +77,7 @@ def _get_quantizer_config(model_id, revision):
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
         desc_act = data["quantization_config"].get("desc_act", False)
         modules_to_not_convert = data["quantization_config"].get(
-            "modules_to_not_convert", None
+            "modules_to_not_convert", []
         )
     except Exception:
         filename = "quantize_config.json"
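The final hunk, in the quantizer-config parsing (_get_quantizer_config), changes the fallback for modules_to_not_convert from None to an empty list, matching the List[str] annotation earlier in the diff. A small, self-contained example of parsing such a config; the JSON content below is made up for illustration and not taken from a real checkpoint:

import json

# Hypothetical excerpt of a config.json carrying a GPTQ quantization_config.
raw = json.loads("""
{
  "quantization_config": {
    "bits": 4,
    "group_size": 128,
    "desc_act": false,
    "checkpoint_format": "gptq",
    "modules_to_not_convert": ["lm_head"]
  }
}
""")

data = raw["quantization_config"]
checkpoint_format = data.get("checkpoint_format")
desc_act = data.get("desc_act", False)
# Defaulting to [] instead of None keeps later membership checks trivial.
modules_to_not_convert = data.get("modules_to_not_convert", [])
print(modules_to_not_convert)  # ['lm_head']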