mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 00:12:08 +00:00
use DefaultWeightsLoader in skip modules
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
parent
b7bdbbd8c0
commit
e66bbfff2e
@ -10,7 +10,7 @@ from text_generation_server.utils.weights import (
|
|||||||
Weight,
|
Weight,
|
||||||
Weights,
|
Weights,
|
||||||
WeightsLoader,
|
WeightsLoader,
|
||||||
UnquantizedWeight,
|
DefaultWeightsLoader,
|
||||||
)
|
)
|
||||||
|
|
||||||
if SYSTEM == "ipex":
|
if SYSTEM == "ipex":
|
||||||
@ -95,7 +95,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
quant_method: str,
|
quant_method: str,
|
||||||
quantize: str,
|
quantize: str,
|
||||||
sym: bool,
|
sym: bool,
|
||||||
modules_to_not_convert: Optional[List[str]],
|
modules_to_not_convert: List[str],
|
||||||
):
|
):
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
self.desc_act = desc_act
|
self.desc_act = desc_act
|
||||||
@ -117,8 +117,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
|
||||||
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
|
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
|
||||||
w = weights.get_tensor(f"{prefix}.weight")
|
return DefaultWeightsLoader.get_weights(weights, prefix)
|
||||||
return UnquantizedWeight(w)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = weights.get_tensor(f"{prefix}.qweight")
|
qweight = weights.get_tensor(f"{prefix}.qweight")
|
||||||
@ -200,10 +199,9 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
block_sizes: Union[int, List[int]],
|
block_sizes: Union[int, List[int]],
|
||||||
):
|
):
|
||||||
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
|
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
|
||||||
w = weights.get_packed_sharded(
|
return DefaultWeightsLoader.get_weights_col_packed(
|
||||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
weights, prefix, block_sizes
|
||||||
)
|
)
|
||||||
return UnquantizedWeight(w)
|
|
||||||
try:
|
try:
|
||||||
qweight = weights.get_packed_sharded(
|
qweight = weights.get_packed_sharded(
|
||||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||||
@ -256,10 +254,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
|
|
||||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
|
if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
|
||||||
w = torch.cat(
|
return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
|
||||||
[weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes], dim=dim
|
|
||||||
)
|
|
||||||
return UnquantizedWeight(w)
|
|
||||||
try:
|
try:
|
||||||
qweight = torch.cat(
|
qweight = torch.cat(
|
||||||
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||||
@ -339,8 +334,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
|
||||||
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
|
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
|
||||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
return DefaultWeightsLoader.get_weights_row(weights, prefix)
|
||||||
return UnquantizedWeight(w)
|
|
||||||
try:
|
try:
|
||||||
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
@ -77,7 +77,7 @@ def _get_quantizer_config(model_id, revision):
|
|||||||
checkpoint_format = data["quantization_config"].get("checkpoint_format")
|
checkpoint_format = data["quantization_config"].get("checkpoint_format")
|
||||||
desc_act = data["quantization_config"].get("desc_act", False)
|
desc_act = data["quantization_config"].get("desc_act", False)
|
||||||
modules_to_not_convert = data["quantization_config"].get(
|
modules_to_not_convert = data["quantization_config"].get(
|
||||||
"modules_to_not_convert", None
|
"modules_to_not_convert", []
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
filename = "quantize_config.json"
|
filename = "quantize_config.json"
|
||||||
|
Loading…
Reference in New Issue
Block a user