Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
Add modules_to_not_convert in quantized model (#3053)
* fix modules_to_not_convert
* fix format
* fix tp quant skip
* revert unquantized changes
* use DefaultWeightsLoader in skip modules

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
commit cae0cbe87d
parent bbe218a4f7
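For context: some quantized checkpoints list modules that were deliberately left unquantized (commonly lm_head, or vision layers in multimodal models) under modules_to_not_convert in the quantization_config section of config.json. A hypothetical excerpt, with placeholder values, showing where the field this commit starts honoring lives:

# Hypothetical quantization_config excerpt; all values are placeholders,
# not taken from a real model.
quantization_config = {
    "quant_method": "gptq",
    "bits": 4,
    "sym": True,
    "desc_act": False,
    "modules_to_not_convert": ["lm_head"],
}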
server/text_generation_server/layers/gptq/__init__.py

@@ -6,7 +6,12 @@ import torch
 from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+from text_generation_server.utils.weights import (
+    Weight,
+    Weights,
+    WeightsLoader,
+    DefaultWeightsLoader,
+)
 
 if SYSTEM == "ipex":
     from .ipex import QuantLinear
@@ -90,6 +95,7 @@ class GPTQWeightsLoader(WeightsLoader):
         quant_method: str,
         quantize: str,
         sym: bool,
+        modules_to_not_convert: List[str],
     ):
         self.bits = bits
         self.desc_act = desc_act
@@ -97,6 +103,7 @@ class GPTQWeightsLoader(WeightsLoader):
         self.quant_method = quant_method
         self.quantize = quantize
         self.sym = sym
+        self.modules_to_not_convert = modules_to_not_convert
 
     def get_weights(self, weights: Weights, prefix: str):
         self._get_gptq_params(weights)
@@ -109,6 +116,9 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights(weights, prefix)
+
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
         except RuntimeError:
@@ -175,12 +185,23 @@ class GPTQWeightsLoader(WeightsLoader):
             use_exllama=use_exllama,
         )
 
+    def is_layer_skipped_quantization(
+        self, prefix: str, modules_to_not_convert: List[str]
+    ):
+        if modules_to_not_convert is None:
+            return False
+        return any(module_name in prefix for module_name in modules_to_not_convert)
+
     def get_weights_col_packed(
         self,
         weights: Weights,
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights_col_packed(
+                weights, prefix, block_sizes
+            )
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
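The skip test added above is a plain substring match over the weight prefix. A minimal standalone sketch of its behavior, with hypothetical prefixes (the function body is copied from the hunk, lifted out of the class):

from typing import List, Optional

def is_layer_skipped_quantization(
    prefix: str, modules_to_not_convert: Optional[List[str]]
) -> bool:
    # Same logic as the method added in the hunk above.
    if modules_to_not_convert is None:
        return False
    return any(module_name in prefix for module_name in modules_to_not_convert)

print(is_layer_skipped_quantization("lm_head", ["lm_head"]))  # True
print(is_layer_skipped_quantization("model.layers.0.self_attn.q_proj", ["lm_head"]))  # False
print(is_layer_skipped_quantization("lm_head", None))  # False

Because the test is `module_name in prefix`, an overly generic entry such as "proj" would skip every projection layer; entries are expected to be distinctive fragments like "lm_head".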
@@ -232,6 +253,8 @@ class GPTQWeightsLoader(WeightsLoader):
         )
 
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -310,6 +333,8 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights_row(weights, prefix)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
server/text_generation_server/utils/quantization.py

@@ -21,6 +21,7 @@ class _QuantizerConfig:
     quant_method: str
     sym: bool
     weight_block_size: Optional[List[int]]
+    modules_to_not_convert: Optional[List[str]]
 
 
 @dataclass
@@ -51,6 +52,7 @@ def _get_quantizer_config(model_id, revision):
     sym = False
     desc_act = False
     weight_block_size = None
+    modules_to_not_convert = None
 
     filename = "config.json"
     try:
@@ -73,7 +75,10 @@ def _get_quantizer_config(model_id, revision):
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        desc_act = data["quantization_config"]["desc_act"]
+        desc_act = data["quantization_config"].get("desc_act", False)
+        modules_to_not_convert = data["quantization_config"].get(
+            "modules_to_not_convert", []
+        )
     except Exception:
         filename = "quantize_config.json"
         try:
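Note this hunk also changes desc_act from a hard key lookup to .get("desc_act", False): previously a quantization_config without that key raised KeyError, dropping into the except branch and the quantize_config.json fallback even when config.json was otherwise usable. A small sketch of the difference, on a hypothetical config dict:

qc = {"quant_method": "gptq", "bits": 4}  # hypothetical: no "desc_act" key

# Old behavior: qc["desc_act"] raises KeyError for such configs.
# New behavior: missing keys fall back to safe defaults.
desc_act = qc.get("desc_act", False)  # -> False
modules_to_not_convert = qc.get("modules_to_not_convert", [])  # -> []
print(desc_act, modules_to_not_convert)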
@@ -110,6 +115,7 @@ def _get_quantizer_config(model_id, revision):
         sym=sym,
         desc_act=desc_act,
         weight_block_size=weight_block_size,
+        modules_to_not_convert=modules_to_not_convert,
     )
 
 
@@ -159,6 +165,7 @@ def get_loader(
             quant_method=quantizer_config.quant_method,
             quantize=quantize,
             sym=quantizer_config.sym,
+            modules_to_not_convert=quantizer_config.modules_to_not_convert,
         )
     elif quantize == "bitsandbytes":
         from text_generation_server.layers.bnb import BNBWeight
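End to end, the new field flows from the parsed _QuantizerConfig into the GPTQ loader constructed here. A hedged sketch of the resulting object, with placeholder values; the groupsize parameter is assumed from the loader's full signature, which these hunks only show in part:

# Sketch of the wiring, not the TGI call site; values are placeholders and
# groupsize is an assumed constructor parameter not visible in the hunks above.
from text_generation_server.layers.gptq import GPTQWeightsLoader

loader = GPTQWeightsLoader(
    bits=4,
    desc_act=False,
    groupsize=128,
    quant_method="gptq",
    quantize="gptq",
    sym=True,
    modules_to_not_convert=["lm_head"],
)
# Any get_weights* call whose prefix matches an entry now bypasses GPTQ
# loading and falls back to DefaultWeightsLoader, i.e. the module's
# weights are loaded unquantized.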