From 75b55efcc757423d92b07ab74684f46b23eeb1ef Mon Sep 17 00:00:00 2001
From: Daniël de Kok
Date: Wed, 24 Jul 2024 13:27:20 +0000
Subject: [PATCH] server quantize: store quantizer config in standard format

- Create a `quantization_config` option in the model config.
- Don't store the quantizer config in tensors anymore.
---
 .../layers/gptq/quantize.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py
index 0271d913..c199023f 100644
--- a/server/text_generation_server/layers/gptq/quantize.py
+++ b/server/text_generation_server/layers/gptq/quantize.py
@@ -16,7 +16,7 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear
 from loguru import logger
 from typing import Optional
 
-from text_generation_server.utils.weights import DefaultWeightsLoader
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
 
 DEV = torch.device("cuda:0")
 
@@ -894,7 +894,7 @@ def quantize(
         dtype=torch.float16,
         process_group=process_group,
         aliases={"embed_tokens.weight": ["lm_head.weight"]},
-        weights_loader=DefaultWeightsLoader(),
+        weights_loader=DefaultWeightsLoader(UnquantizedWeight),
     )
     hooks = []
     for name, module in model.named_modules():
@@ -957,9 +957,6 @@ def quantize(
 
     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
-    state_dict["gptq_bits"] = torch.LongTensor([bits])
-    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
-    state_dict["gptq_sym"] = torch.BoolTensor([sym])
 
     max_shard_size = "10GB"
     shards, index = shard_checkpoint(
@@ -991,6 +988,15 @@ def quantize(
         f"index located at {save_index_file}."
     )
     config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
+    config.quantization_config = {
+        "bits": bits,
+        "group_size": groupsize,
+        "damp_percent": percdamp,
+        "desc_act": act_order,
+        "static_groups": False,
+        "sym": sym,
+        "quant_method": "gptq",
+    }
     config.save_pretrained(output_dir)
     logger.info("Saved config")
     logger.info("Saving tokenizer")
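
After this change, the quantizer settings live in `config.json` under the
standard `quantization_config` key instead of the ad-hoc
`gptq_bits`/`gptq_groupsize`/`gptq_sym` checkpoint tensors, so any
GPTQ-aware loader can discover them. A minimal sketch of reading them back
with `transformers`; the checkpoint path is hypothetical and stands in for
the `output_dir` passed to `quantize()`:

    from transformers import AutoConfig

    # Hypothetical path: wherever quantize() wrote the checkpoint.
    config = AutoConfig.from_pretrained("./quantized-model")

    # Extra keys in config.json are kept as plain attributes on the
    # loaded config, so the dict written by the patch comes back as-is.
    qcfg = config.quantization_config
    assert qcfg["quant_method"] == "gptq"
    print(qcfg["bits"], qcfg["group_size"], qcfg["sym"])

Because this is the same shape that `transformers` and `optimum` already
emit for GPTQ models, recent `transformers` releases (with the GPTQ
dependencies installed) can detect the quantization method from the config
alone, with no TGI-specific handling.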