server quantize: store quantizer config in standard format
- Create `quantization_config` option in the model config.
- Don't store the quantizer config in tensors anymore.
commit 75b55efcc7 (parent 8642250602)
@@ -16,7 +16,7 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear
 from loguru import logger
 from typing import Optional
 
-from text_generation_server.utils.weights import DefaultWeightsLoader
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
 
 DEV = torch.device("cuda:0")
 
@@ -894,7 +894,7 @@ def quantize(
         dtype=torch.float16,
         process_group=process_group,
         aliases={"embed_tokens.weight": ["lm_head.weight"]},
-        weights_loader=DefaultWeightsLoader(),
+        weights_loader=DefaultWeightsLoader(UnquantizedWeight),
     )
     hooks = []
     for name, module in model.named_modules():
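The `weights_loader` change tracks an API update in TGI's weight-loading layer: `DefaultWeightsLoader` is now parameterized by the weight class it should produce for plain tensors. A minimal sketch of the new call pattern, using only the names visible in this diff (not the loader's full signature):

    from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight

    # The loader now takes the weight class used to wrap unquantized tensors,
    # rather than assuming a default implicitly.
    weights_loader = DefaultWeightsLoader(UnquantizedWeight)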
@@ -957,9 +957,6 @@ def quantize(
 
     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
-    state_dict["gptq_bits"] = torch.LongTensor([bits])
-    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
-    state_dict["gptq_sym"] = torch.BoolTensor([sym])
 
     max_shard_size = "10GB"
     shards, index = shard_checkpoint(
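For context on the removed lines: the quantizer settings used to travel as extra tensors inside the checkpoint itself, so a consumer had to read them back from a weight shard. A hypothetical sketch of that old read path (the shard file name is illustrative):

    from safetensors import safe_open

    # Old scheme (removed above): quantizer settings were stored as tensors
    # next to the weights and had to be fished back out of a shard.
    with safe_open("model.safetensors", framework="pt") as f:
        bits = int(f.get_tensor("gptq_bits")[0])
        groupsize = int(f.get_tensor("gptq_groupsize")[0])
        sym = bool(f.get_tensor("gptq_sym")[0])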
@@ -991,6 +988,15 @@ def quantize(
         f"index located at {save_index_file}."
     )
     config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
+    config.quantization_config = {
+        "bits": bits,
+        "group_size": groupsize,
+        "damp_percent": percdamp,
+        "desc_act": act_order,
+        "static_groups": False,
+        "sym": sym,
+        "quant_method": "gptq",
+    }
     config.save_pretrained(output_dir)
     logger.info("Saved config")
     logger.info("Saving tokenizer")
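With this hunk, `config.save_pretrained(output_dir)` writes the quantizer parameters into the standard `quantization_config` section of `config.json`, where GPTQ-aware loaders expect to find them. A minimal read-back sketch (the path is a placeholder for the quantize command's `output_dir`):

    from transformers import AutoConfig

    # Load the config written by config.save_pretrained(output_dir) above.
    config = AutoConfig.from_pretrained("./quantized-model")

    # The settings are now ordinary config entries rather than tensors:
    # bits, group_size, damp_percent, desc_act, static_groups, sym, quant_method.
    print(config.quantization_config["quant_method"])  # "gptq"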
|