Fixing few things

commit ffe8fc4699, parent dadbbc27d5
Ubuntu authored 2023-06-13 18:58:09 +00:00, committed by Nicolas Patry
3 changed files with 18 additions and 18 deletions

server/text_generation_server/cli.py

@@ -150,7 +150,6 @@ def download_weights(
         # Convert pytorch weights to safetensors
         utils.convert_files(local_pt_files, local_st_files)
 
-
 @app.command()
 def quantize(
     model_id: str,
@@ -158,8 +157,9 @@ def quantize(
     revision: Optional[str] = None,
     logger_level: str = "INFO",
     json_output: bool = False,
+    trust_remote_code: bool = False,
 ):
-    extension: str = (".safetensors",)
+    extension: str = ".safetensors",
     # Remove default handler
     logger.remove()
     logger.add(
@@ -171,15 +171,12 @@ def quantize(
         backtrace=True,
         diagnose=False,
     )
 
-    download_weights(
-        model_id=model_id,
-        revision=revision,
-        logger_level=logger_level,
-        json_output=json_output,
-    )
+    download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output)
     from text_generation_server.utils.gptq.quantize import quantize
 
-    quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir)
+    quantize(
+        model_id=model_id, bits=4, groupsize=128, output_dir=output_dir, trust_remote_code=trust_remote_code
+    )
 
 if __name__ == "__main__":
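Note on this change: the new `trust_remote_code` option is threaded from the CLI command down into the GPTQ script. Because the command is defined with typer, a `bool = False` parameter automatically becomes an on/off flag. A minimal standalone sketch of that mechanism (toy command and names, not TGI's actual code):

import typer

app = typer.Typer()

@app.command()
def quantize(model_id: str, output_dir: str, trust_remote_code: bool = False):
    # typer exposes the bool parameter as a flag, e.g.:
    #   python app.py my-model ./out --trust-remote-code
    print(f"model={model_id} out={output_dir} trust_remote_code={trust_remote_code}")

if __name__ == "__main__":
    app()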

server/text_generation_server/utils/gptq/quant_linear.py

@@ -304,11 +304,14 @@ class QuantLinearFunction(torch.autograd.Function):
 class QuantLinear(nn.Module):
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         super().__init__()
-        self.qweight = qweight
-        self.qzeros = qzeros
-        self.scales = scales
-        self.g_idx = g_idx
-        self.bias = bias
+        self.qweight = self.register_buffer("qweight", qweight)
+        self.qzeros = self.register_buffer("qzeros", qzeros)
+        self.scales = self.register_buffer("scales", scales)
+        self.g_idx = self.register_buffer("g_idx", g_idx)
+        if bias is not None:
+            self.bias = self.register_buffer("bias", bias)
+        else:
+            self.bias = None
         if bits not in [2, 4, 8]:
             raise NotImplementedError("Only 2,4,8 bits are supported.")
         self.bits = bits
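Note on this change: `register_buffer` (a standard `nn.Module` method) makes the quantization tensors part of the module's state, so they follow `.to(device)` and are included in `state_dict()`, which plain attribute assignment does not do. One PyTorch detail relevant when reading the new lines: `register_buffer` returns None. A minimal standalone sketch of the registration pattern (toy class, not TGI's actual code):

from typing import Optional

import torch
from torch import nn

class TinyQuantLinear(nn.Module):
    def __init__(self, qweight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        super().__init__()
        # Buffers are readable as attributes, move with .to()/.cuda(),
        # and are serialized in state_dict(); register_buffer() returns None.
        self.register_buffer("qweight", qweight)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None

m = TinyQuantLinear(torch.zeros(4, 4, dtype=torch.int32))
print(sorted(m.state_dict().keys()))  # ['qweight'] -- the buffer is saved
print(m.qweight.shape)                # accessible like a normal attribute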

server/text_generation_server/utils/gptq/quantize.py

@@ -937,9 +937,9 @@ def pack(model, quantizers, bits, groupsize):
     # print('max memory(MiB):', max_memory)
 
 
-def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
+def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool):
     print("loading model")
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0")
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0", trust_remote_code=trust_remote_code)
     print("LOADED model")
     model.seqlen = 2048
 
@@ -1002,8 +1002,8 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
     from transformers.modeling_utils import shard_checkpoint
     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
-    state_dict["gptq_bits"] = torch.LongTensor(bits)
-    state_dict["gptq_groupsize"] = torch.LongTensor(groupsize)
+    state_dict["gptq_bits"] = torch.LongTensor([bits])
+    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
 
     max_shard_size = "10GB"
     shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors")
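Note on the last two replaced lines: called with a bare integer, `torch.LongTensor(n)` treats the argument as a size and returns an uninitialized tensor with n elements, while `torch.LongTensor([n])` builds a one-element tensor containing n, which is what the serialized GPTQ metadata needs. A quick illustration:

import torch

bits, groupsize = 4, 128
torch.LongTensor(bits)         # shape (4,): uninitialized values
torch.LongTensor([bits])       # tensor([4]): the value itself
torch.LongTensor([groupsize])  # tensor([128])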