Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
fix modules_to_not_convert

Honor the modules_to_not_convert list from a checkpoint's quantization_config:
modules named there are now loaded as plain unquantized weights instead of
GPTQ/AWQ-quantized ones, and desc_act is read defensively since some real
configs omit it. The unquantized MoE layer additionally forwards scoring_func
and e_score_correction_bias to expert selection.

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
parent 06dfe9abfe
commit 0bad926fb8
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -6,7 +6,7 @@ import torch
 from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+from text_generation_server.utils.weights import Weight, Weights, WeightsLoader, UnquantizedWeight
 
 if SYSTEM == "ipex":
     from .ipex import QuantLinear
@@ -90,6 +90,7 @@ class GPTQWeightsLoader(WeightsLoader):
         quant_method: str,
         quantize: str,
         sym: bool,
+        modules_to_not_convert: Optional[List[str]],
     ):
         self.bits = bits
         self.desc_act = desc_act
@@ -97,6 +98,7 @@ class GPTQWeightsLoader(WeightsLoader):
         self.quant_method = quant_method
         self.quantize = quantize
         self.sym = sym
+        self.modules_to_not_convert = modules_to_not_convert
 
     def get_weights(self, weights: Weights, prefix: str):
         self._get_gptq_params(weights)
@@ -109,6 +111,10 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            w = weights.get_tensor(f"{prefix}.weight")
+            return UnquantizedWeight(w)
+
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
         except RuntimeError:
@@ -171,9 +177,15 @@ class GPTQWeightsLoader(WeightsLoader):
             g_idx=g_idx,
             bits=self.bits,
             groupsize=self.groupsize,
+            use_awq_kernel=self.quantize == "awq",
             use_exllama=use_exllama,
         )
 
+    def is_layer_skipped_quantization(self, prefix: str, modules_to_not_convert: List[str]):
+        if modules_to_not_convert is None:
+            return False
+        return any(module_name in prefix for module_name in modules_to_not_convert)
+
     def get_weights_col_packed(
         self,
         weights: Weights,
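Note: the skip test is a plain substring match on the fully qualified weight
prefix, and a skipped module reads {prefix}.weight instead of the GPTQ
qweight/qzeros/scales tensors. A standalone sketch of that behavior (the
prefixes and module names below are hypothetical, not taken from this diff):

    from typing import List, Optional

    def is_layer_skipped_quantization(
        prefix: str, modules_to_not_convert: Optional[List[str]]
    ) -> bool:
        # Same logic as the new method: None means nothing is skipped;
        # any substring hit on the weight prefix skips quantization.
        if modules_to_not_convert is None:
            return False
        return any(module_name in prefix for module_name in modules_to_not_convert)

    # Hypothetical prefixes, for illustration only:
    assert is_layer_skipped_quantization("lm_head", ["lm_head"])
    assert not is_layer_skipped_quantization("model.layers.0.self_attn.q_proj", ["lm_head"])
    assert not is_layer_skipped_quantization("lm_head", None)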
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -85,6 +85,8 @@ class UnquantizedSparseMoELayer(nn.Module):
             use_grouped_topk=self.n_expert_group is not None,
             num_expert_group=self.n_expert_group,
             topk_group=self.topk_group,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
         )
         return fused_moe(
             x,
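For context, scoring_func and e_score_correction_bias are expert-routing
parameters (DeepSeek-style models use a sigmoid scoring function with a
per-expert correction bias); this hunk simply forwards them to the
expert-selection call. A toy illustration of what a score-correction bias
does during top-k routing (shapes and names are illustrative, not the fused
kernel's API):

    import torch

    def route_topk(scores: torch.Tensor, bias: torch.Tensor, top_k: int):
        # The bias only influences which experts get picked...
        topk_vals, topk_ids = (scores + bias).topk(top_k, dim=-1)
        # ...while the returned routing weights use the uncorrected scores.
        weights = scores.gather(-1, topk_ids)
        return weights, topk_ids

    scores = torch.rand(2, 8)   # 2 tokens, 8 experts
    bias = torch.zeros(8)       # zero bias degenerates to plain top-k
    w, ids = route_topk(scores, bias, top_k=2)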
--- a/server/text_generation_server/utils/quantization.py
+++ b/server/text_generation_server/utils/quantization.py
@@ -21,6 +21,7 @@ class _QuantizerConfig:
     quant_method: str
     sym: bool
     weight_block_size: Optional[List[int]]
+    modules_to_not_convert: Optional[List[str]]
 
 
 @dataclass
@@ -51,6 +52,7 @@ def _get_quantizer_config(model_id, revision):
     sym = False
     desc_act = False
     weight_block_size = None
+    modules_to_not_convert = None
 
     filename = "config.json"
     try:
@@ -73,7 +75,8 @@ def _get_quantizer_config(model_id, revision):
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        desc_act = data["quantization_config"]["desc_act"]
+        desc_act = data["quantization_config"].get("desc_act", False)
+        modules_to_not_convert = data["quantization_config"].get("modules_to_not_convert", None)
     except Exception:
         filename = "quantize_config.json"
         try:
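For reference, this is the shape of quantization_config the parser now
tolerates; the checkpoint contents below are hypothetical:

    # Hypothetical config.json payload for a GPTQ checkpoint:
    data = {
        "quantization_config": {
            "quant_method": "gptq",
            "bits": 4,
            "group_size": 128,
            "sym": True,
            # "desc_act" deliberately absent: some real configs omit it,
            # which is why the lookup now uses .get("desc_act", False).
            "modules_to_not_convert": ["lm_head"],
        }
    }

    desc_act = data["quantization_config"].get("desc_act", False)  # False
    modules_to_not_convert = data["quantization_config"].get(
        "modules_to_not_convert", None
    )  # ["lm_head"]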
@@ -110,6 +113,7 @@ def _get_quantizer_config(model_id, revision):
         sym=sym,
         desc_act=desc_act,
         weight_block_size=weight_block_size,
+        modules_to_not_convert=modules_to_not_convert,
     )
 
 
@@ -159,6 +163,7 @@ def get_loader(
             quant_method=quantizer_config.quant_method,
             quantize=quantize,
             sym=quantizer_config.sym,
+            modules_to_not_convert=quantizer_config.modules_to_not_convert,
         )
     elif quantize == "bitsandbytes":
         from text_generation_server.layers.bnb import BNBWeight
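End to end, the list travels config.json → _QuantizerConfig → get_loader →
GPTQWeightsLoader → is_layer_skipped_quantization. A minimal stand-in sketch
of that plumbing (simplified stubs, not the real class signatures):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class QuantizerConfigStub:  # stand-in for _QuantizerConfig (fields abbreviated)
        quant_method: str
        sym: bool
        modules_to_not_convert: Optional[List[str]]

    @dataclass
    class GPTQLoaderStub:  # stand-in for GPTQWeightsLoader (fields abbreviated)
        quant_method: str
        quantize: str
        sym: bool
        modules_to_not_convert: Optional[List[str]]

    cfg = QuantizerConfigStub("gptq", False, ["lm_head"])
    loader = GPTQLoaderStub(
        quant_method=cfg.quant_method,
        quantize="gptq",
        sym=cfg.sym,
        modules_to_not_convert=cfg.modules_to_not_convert,
    )
    assert loader.modules_to_not_convert == ["lm_head"]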