Reworking the quantization script so it's still universal (not llama specific) but should work on more configurations (no need for 2 GPUs, less RAM usage).
Nicolas Patry 2023-07-11 17:25:26 +00:00
parent 2c4bf88268
commit b3f830abc3
2 changed files with 131 additions and 12 deletions
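
The gist of the change: instead of loading the full fp16 checkpoint across two GPUs with accelerate's device_map, the model is now instantiated empty on the meta device, and each module's weights are materialized on a single GPU only while that module runs, then pushed back to CPU. Below is a minimal sketch of this lazy load/unload pattern, assuming only a weights object exposing get_tensor(name) (as text_generation_server.utils.Weights does); the hook names here are hypothetical stand-ins for the commit's load_weights_pre_hook and load_weights_post_hook:

import torch
from torch import nn

def materialize_pre_hook(module_name, weights):
    # Before the module's forward: replace meta-device parameters with
    # real tensors fetched from the checkpoint, placed directly on GPU.
    def inner(module, args):
        for name, param in list(module.named_parameters(recurse=False)):
            if param.device == torch.device("meta"):
                tensor = weights.get_tensor(f"{module_name}.{name}")
                setattr(module, name, nn.Parameter(tensor.to("cuda:0")))
            else:
                setattr(module, name, nn.Parameter(param.to("cuda:0")))
    return inner

def offload_post_hook(module_name, weights):
    # After the module's forward: move its parameters back to CPU so only
    # the currently active layer occupies GPU memory.
    def inner(module, args, output):
        for name, param in list(module.named_parameters(recurse=False)):
            setattr(module, name, nn.Parameter(param.to("cpu")))
        return output
    return inner

The commit's real hooks additionally handle buffers and can recurse into submodules, which is why they address parameters by dotted names via getdeepattr/setdeepattr.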

View File

@@ -194,6 +194,8 @@ def quantize(
     percdamp: float = 0.01,
     act_order: bool = False,
 ):
+    if revision is None:
+        revision = "main"
     download_weights(
         model_id=model_id,
         revision=revision,
@@ -207,6 +209,7 @@ def quantize(
         bits=4,
         groupsize=128,
         output_dir=output_dir,
+        revision=revision,
         trust_remote_code=trust_remote_code,
         upload_to_model_id=upload_to_model_id,
         percdamp=percdamp,
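
The CLI-side change above just normalizes revision to "main" when unset and threads it through to the GPTQ routine so the right weight files get resolved. A hypothetical invocation of the updated command (parameter names follow the hunks above; the model id and defaults are illustrative):

quantize(
    model_id="bigscience/bloom-560m",  # any non-llama architecture is the point
    output_dir="./quantized",
    revision=None,  # normalized to "main" inside quantize()
    trust_remote_code=False,
    upload_to_model_id=None,
    percdamp=0.01,
    act_order=False,
)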

View File

@@ -13,6 +13,9 @@ import transformers
 from huggingface_hub import HfApi
 import numpy as np
 import torch
+from accelerate import init_empty_weights
+from text_generation_server.utils import initialize_torch_distributed, Weights
+from text_generation_server.utils.hub import weight_files
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 from loguru import logger
 from typing import Optional
@@ -38,7 +41,6 @@ class Quantizer(nn.Module):
         maxshrink=0.8,
         trits=False,
     ):
         self.maxq = torch.tensor(2**bits - 1)
         self.perchannel = perchannel
         self.sym = sym
@@ -600,6 +602,8 @@ def sequential(
     nsamples,
     bits,
     groupsize,
+    *,
+    hooks,
     percdamp=0.01,
     sym: bool = False,
     act_order: bool = False,
@@ -637,7 +641,7 @@ def sequential(
     layers[0] = Catcher(layers[0])
     for batch in dataloader:
         try:
-            model(batch[0])
+            model(batch[0].cuda())
         except ValueError:
             pass
     layers[0] = layers[0].module
@@ -646,6 +650,8 @@ def sequential(
     # model.model.embed_tokens = model.model.embed_tokens.cpu()
     # model.model.norm = model.model.norm.cpu()
     torch.cuda.empty_cache()
+    for hook in hooks:
+        hook.remove()
 
     outs = torch.zeros_like(inps)
@@ -662,10 +668,8 @@ def sequential(
         print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
         print("+==================+==============+============+===========+=======+")
 
-        from accelerate.hooks import remove_hook_from_submodules
-        layer = layers[i].to(dev)
-        remove_hook_from_submodules(layer)
+        layer = layers[i]
+        layer.load()
         full = find_layers(layer)
         sequential = [list(full.keys())]
@@ -677,6 +681,7 @@ def sequential(
                 gptq[name].quantizer.configure(
                     bits, perchannel=True, sym=sym, mse=False
                 )
+        pass
 
         def add_batch(name):
             def tmp(_, inp, out):
@@ -688,7 +693,6 @@ def sequential(
         for name in subset:
             handles.append(subset[name].register_forward_hook(add_batch(name)))
         for j in range(nsamples):
             outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
         for h in handles:
             h.remove()
@@ -714,7 +718,7 @@ def sequential(
         for j in range(nsamples):
             outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
-        layers[i] = layer.cpu()
+        layer.unload()
         del layer
         del gptq
         torch.cuda.empty_cache()
@@ -768,24 +772,136 @@ def pack(model, quantizers, bits, groupsize):
     return model
 
+
+def setdeepattr(module, full_name, tensor):
+    current = module
+    tokens = full_name.split(".")
+    for token in tokens[:-1]:
+        current = getattr(current, token)
+    setattr(current, tokens[-1], tensor)
+
+
+def getdeepattr(module, full_name):
+    current = module
+    tokens = full_name.split(".")
+    for token in tokens:
+        current = getattr(current, token)
+    return current
+
+
+def load_weights_pre_hook(module_name, weights, recursive=False):
+    def inner(module, args):
+        print(f"Pre hook {module_name}")
+        local_params = {}
+        for k, v in module.named_parameters():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for k, v in module.named_buffers():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+
+        for local_param in local_params:
+            current_tensor = getdeepattr(module, local_param)
+            if current_tensor.device == torch.device("meta"):
+                # print(f"Loading {local_param}")
+                if module_name:
+                    tensor_name = f"{module_name}.{local_param}"
+                else:
+                    tensor_name = local_param
+                tensor = weights.get_tensor(tensor_name)
+                setdeepattr(module, local_param, nn.Parameter(tensor))
+            else:
+                setdeepattr(
+                    module,
+                    local_param,
+                    nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
+                )
+
+    return inner
+
+
+def load_weights_post_hook(module_name, weights, recursive=False):
+    def inner(module, args, output):
+        print(f"Post hook {module_name}")
+        local_params = {}
+        for k, v in module.named_parameters():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for k, v in module.named_buffers():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for local_param in local_params:
+            # print(f"Unloading {local_param}")
+            current_tensor = getdeepattr(module, local_param)
+            setdeepattr(
+                module,
+                local_param,
+                nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
+            )
+        return output
+
+    return inner
+
+
 def quantize(
     model_id: str,
     bits: int,
     groupsize: int,
     output_dir: str,
+    revision: str,
     trust_remote_code: bool,
     upload_to_model_id: Optional[str],
     percdamp: float,
     act_order: bool,
 ):
     print("loading model")
-    model = AutoModelForCausalLM.from_pretrained(
+    config = AutoConfig.from_pretrained(
         model_id,
-        torch_dtype=torch.float16,
-        device_map="balanced_low_0",
         trust_remote_code=trust_remote_code,
     )
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
     model = model.eval()
     print("LOADED model")
+    files = weight_files(model_id, revision, extension=".safetensors")
+    process_group, _, _ = initialize_torch_distributed()
+    weights = Weights(
+        files,
+        device=torch.device("cuda:0"),
+        dtype=torch.float16,
+        process_group=process_group,
+        aliases={"embed_tokens.weight": ["lm_head.weight"]},
+    )
+    hooks = []
+    for name, module in model.named_modules():
+
+        def load(module, name):
+            def _load():
+                load_weights_pre_hook(name, weights, recursive=True)(module, None)
+
+            return _load
+
+        def unload(module, name):
+            def _unload():
+                load_weights_post_hook(name, weights, recursive=True)(
+                    module, None, None
+                )
+
+            return _unload
+
+        module.load = load(module, name)
+        module.unload = unload(module, name)
+        hooks.append(
+            module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
+        )
+        hooks.append(
+            module.register_forward_hook(load_weights_post_hook(name, weights))
+        )
     model.seqlen = 2048
 
     dataset = "wikitext2"
@@ -806,6 +922,7 @@ def quantize(
         groupsize,
         percdamp=percdamp,
         act_order=act_order,
+        hooks=hooks,
     )
     print(time.time() - tick)
@@ -858,7 +975,6 @@ def quantize(
     logger.info("Saved tokenizer")
 
     if upload_to_model_id:
         api = HfApi()
         api.upload_folder(