diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index b4539c46..8d2a1ce7 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -150,7 +150,6 @@ def download_weights(
     # Convert pytorch weights to safetensors
     utils.convert_files(local_pt_files, local_st_files)
 
-
 @app.command()
 def quantize(
     model_id: str,
@@ -158,8 +157,9 @@ def quantize(
     revision: Optional[str] = None,
     logger_level: str = "INFO",
     json_output: bool = False,
+    trust_remote_code: bool = False,
 ):
-    extension: str = (".safetensors",)
+    extension: str = ".safetensors",
     # Remove default handler
     logger.remove()
     logger.add(
@@ -171,15 +171,12 @@ def quantize(
         backtrace=True,
         diagnose=False,
     )
-    download_weights(
-        model_id=model_id,
-        revision=revision,
-        logger_level=logger_level,
-        json_output=json_output,
-    )
+    download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output)
     from text_generation_server.utils.gptq.quantize import quantize
 
+    quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir, trust_remote_code=trust_remote_code)
+
+
-    quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir)
 
 
 if __name__ == "__main__":
diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py
index 3093a700..738f54b7 100644
--- a/server/text_generation_server/utils/gptq/quant_linear.py
+++ b/server/text_generation_server/utils/gptq/quant_linear.py
@@ -304,11 +304,14 @@ class QuantLinearFunction(torch.autograd.Function):
 class QuantLinear(nn.Module):
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         super().__init__()
-        self.qweight = qweight
-        self.qzeros = qzeros
-        self.scales = scales
-        self.g_idx = g_idx
-        self.bias = bias
+        self.register_buffer("qweight", qweight)
+        self.register_buffer("qzeros", qzeros)
+        self.register_buffer("scales", scales)
+        self.register_buffer("g_idx", g_idx)
+        if bias is not None:
+            self.register_buffer("bias", bias)
+        else:
+            self.bias = None
         if bits not in [2, 4, 8]:
             raise NotImplementedError("Only 2,4,8 bits are supported.")
         self.bits = bits
diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py
index e3fae470..93d86b3a 100644
--- a/server/text_generation_server/utils/gptq/quantize.py
+++ b/server/text_generation_server/utils/gptq/quantize.py
@@ -937,9 +937,9 @@ def pack(model, quantizers, bits, groupsize):
     # print('max memory(MiB):', max_memory)
 
 
-def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
+def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool):
     print("loading model")
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0")
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0", trust_remote_code=trust_remote_code)
     print("LOADED model")
     model.seqlen = 2048
 
@@ -1002,8 +1002,8 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
     from transformers.modeling_utils import shard_checkpoint
 
     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
-    state_dict["gptq_bits"] = torch.LongTensor(bits)
-    state_dict["gptq_groupsize"] = torch.LongTensor(groupsize)
+    state_dict["gptq_bits"] = torch.LongTensor([bits])
+    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
     max_shard_size = "10GB"
     shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors")
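
A minimal standalone sketch (not part of the patch; the class name is illustrative) of the two PyTorch behaviors the hunks above rely on: register_buffer returns None, so the registered tensor must not be re-assigned to the attribute, and wrapping a scalar in a list makes torch.LongTensor store the value rather than allocate an uninitialized tensor of that length.

import torch
from torch import nn


class TinyQuantStub(nn.Module):
    """Hypothetical stand-in for QuantLinear's buffer handling."""

    def __init__(self, qweight: torch.Tensor):
        super().__init__()
        # register_buffer returns None; assigning its result back to the
        # attribute would overwrite the freshly registered buffer.
        self.register_buffer("qweight", qweight)


stub = TinyQuantStub(torch.zeros(2, 2, dtype=torch.int32))
print("qweight" in stub.state_dict())  # True: buffers are serialized with the model
print(stub.to("cpu").qweight.dtype)    # buffers follow .to()/.cuda() moves

bits = 4
print(torch.LongTensor([bits]))        # tensor([4]): one-element tensor holding the value
print(torch.LongTensor(bits).shape)    # torch.Size([4]): uninitialized tensor of length 4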