Mirror of https://github.com/huggingface/text-generation-inference.git
Reworking the quantization script so it stays universal (not Llama-specific) but works on more configurations (no need for 2 GPUs, lower RAM usage).
parent 2c4bf88268
commit b3f830abc3
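The gist of the rework, visible in the hunks below: instead of loading the full fp16 checkpoint across two GPUs with `device_map="balanced_low_0"`, the model is now instantiated on the meta device with `init_empty_weights`, and forward pre/post hooks materialize each module's weights on a single GPU right before it runs and push them back to CPU afterwards. A minimal, self-contained sketch of that pattern, assuming a plain `weight_store` dict and hook helpers as illustrative stand-ins for the script's safetensors-backed `Weights` object:

import torch
import torch.nn as nn
from accelerate import init_empty_weights  # real dependency added by this commit


def lazy_load_pre_hook(name, weight_store, device):
    """Materialize this module's own parameters on `device` just before its forward."""
    def hook(module, args):
        for param_name, param in list(module.named_parameters(recurse=False)):
            full_name = f"{name}.{param_name}" if name else param_name
            if param.device == torch.device("meta"):
                # Stand-in for Weights.get_tensor(...) reading a safetensors shard.
                tensor = weight_store[full_name].to(device)
            else:
                tensor = param.data.to(device)
            setattr(module, param_name, nn.Parameter(tensor, requires_grad=False))
    return hook


def offload_post_hook(name):
    """Move this module's parameters back to CPU once its forward is done."""
    def hook(module, args, output):
        for param_name, param in list(module.named_parameters(recurse=False)):
            setattr(module, param_name, nn.Parameter(param.data.to("cpu"), requires_grad=False))
        return output
    return hook


# Toy model and weight store; the real script reads .safetensors shards instead.
with init_empty_weights():
    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))

weight_store = {
    "0.weight": torch.randn(16, 16), "0.bias": torch.zeros(16),
    "1.weight": torch.randn(4, 16), "1.bias": torch.zeros(4),
}
device = "cuda:0" if torch.cuda.is_available() else "cpu"

for name, module in model.named_modules():
    module.register_forward_pre_hook(lazy_load_pre_hook(name, weight_store, device))
    module.register_forward_hook(offload_post_hook(name))

print(model(torch.randn(1, 16).to(device)).shape)  # torch.Size([1, 4])

In the actual script the store is the `Weights` class over safetensors files, and every module additionally gets explicit `load()`/`unload()` closures so the per-layer GPTQ loop can stream one transformer block onto the GPU at a time.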
@@ -194,6 +194,8 @@ def quantize(
+    percdamp: float = 0.01,
+    act_order: bool = False,
 ):
     if revision is None:
         revision = "main"
     download_weights(
         model_id=model_id,
         revision=revision,
@@ -207,6 +209,7 @@ def quantize(
         bits=4,
         groupsize=128,
         output_dir=output_dir,
         revision=revision,
         trust_remote_code=trust_remote_code,
         upload_to_model_id=upload_to_model_id,
+        percdamp=percdamp,
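Both new options are plumbed through to the GPTQ routine changed below. As a hedged usage sketch, calling that routine directly with the new keyword arguments might look like this (the import path is assumed from the package imports later in this diff, and the model id is only an example; the argument names mirror the signature in the diff):

from text_generation_server.utils.gptq.quantize import quantize  # path assumed from this diff

quantize(
    model_id="bigscience/bloom-560m",   # any causal LM; the script is no longer Llama-specific
    bits=4,
    groupsize=128,
    output_dir="./bloom-560m-gptq",
    revision="main",
    trust_remote_code=False,
    upload_to_model_id=None,
    percdamp=0.01,      # new: dampening as a fraction of the mean Hessian diagonal
    act_order=False,    # new: quantize columns in order of decreasing activation (desc_act)
)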
@@ -13,6 +13,9 @@ import transformers
 from huggingface_hub import HfApi
 import numpy as np
 import torch
+from accelerate import init_empty_weights
+from text_generation_server.utils import initialize_torch_distributed, Weights
+from text_generation_server.utils.hub import weight_files
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 from loguru import logger
 from typing import Optional
@@ -38,7 +41,6 @@ class Quantizer(nn.Module):
         maxshrink=0.8,
         trits=False,
     ):
-
         self.maxq = torch.tensor(2**bits - 1)
         self.perchannel = perchannel
         self.sym = sym
@@ -600,6 +602,8 @@ def sequential(
     nsamples,
     bits,
     groupsize,
+    *,
+    hooks,
     percdamp=0.01,
     sym: bool = False,
     act_order: bool = False,
@@ -637,7 +641,7 @@ def sequential(
     layers[0] = Catcher(layers[0])
     for batch in dataloader:
         try:
-            model(batch[0])
+            model(batch[0].cuda())
         except ValueError:
             pass
     layers[0] = layers[0].module
@@ -646,6 +650,8 @@ def sequential(
     # model.model.embed_tokens = model.model.embed_tokens.cpu()
     # model.model.norm = model.model.norm.cpu()
     torch.cuda.empty_cache()
+    for hook in hooks:
+        hook.remove()
 
     outs = torch.zeros_like(inps)
 
@@ -662,10 +668,8 @@ def sequential(
         print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
         print("+==================+==============+============+===========+=======+")
 
-        from accelerate.hooks import remove_hook_from_submodules
-
-        layer = layers[i].to(dev)
-        remove_hook_from_submodules(layer)
+        layer = layers[i]
+        layer.load()
         full = find_layers(layer)
         sequential = [list(full.keys())]
 
@@ -677,6 +681,7 @@ def sequential(
                 gptq[name].quantizer.configure(
                     bits, perchannel=True, sym=sym, mse=False
                 )
+            pass
 
             def add_batch(name):
                 def tmp(_, inp, out):
@@ -688,7 +693,6 @@ def sequential(
             for name in subset:
                 handles.append(subset[name].register_forward_hook(add_batch(name)))
             for j in range(nsamples):
-
                 outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
             for h in handles:
                 h.remove()
@@ -714,7 +718,7 @@ def sequential(
         for j in range(nsamples):
             outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
 
-        layers[i] = layer.cpu()
+        layer.unload()
         del layer
         del gptq
         torch.cuda.empty_cache()
@@ -768,24 +772,136 @@ def pack(model, quantizers, bits, groupsize):
     return model
 
 
+def setdeepattr(module, full_name, tensor):
+    current = module
+    tokens = full_name.split(".")
+    for token in tokens[:-1]:
+        current = getattr(current, token)
+    setattr(current, tokens[-1], tensor)
+
+
+def getdeepattr(module, full_name):
+    current = module
+    tokens = full_name.split(".")
+    for token in tokens:
+        current = getattr(current, token)
+    return current
+
+
+def load_weights_pre_hook(module_name, weights, recursive=False):
+    def inner(module, args):
+        print(f"Pre hook {module_name}")
+        local_params = {}
+        for k, v in module.named_parameters():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for k, v in module.named_buffers():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+
+        for local_param in local_params:
+            current_tensor = getdeepattr(module, local_param)
+            if current_tensor.device == torch.device("meta"):
+                # print(f"Loading {local_param}")
+                if module_name:
+                    tensor_name = f"{module_name}.{local_param}"
+                else:
+                    tensor_name = local_param
+                tensor = weights.get_tensor(tensor_name)
+                setdeepattr(module, local_param, nn.Parameter(tensor))
+            else:
+                setdeepattr(
+                    module,
+                    local_param,
+                    nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
+                )
+
+    return inner
+
+
+def load_weights_post_hook(module_name, weights, recursive=False):
+    def inner(module, args, output):
+        print(f"Post hook {module_name}")
+        local_params = {}
+        for k, v in module.named_parameters():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for k, v in module.named_buffers():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for local_param in local_params:
+            # print(f"Unloading {local_param}")
+            current_tensor = getdeepattr(module, local_param)
+            setdeepattr(
+                module,
+                local_param,
+                nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
+            )
+        return output
+
+    return inner
+
+
 def quantize(
     model_id: str,
     bits: int,
     groupsize: int,
     output_dir: str,
     revision: str,
     trust_remote_code: bool,
     upload_to_model_id: Optional[str],
+    percdamp: float,
+    act_order: bool,
 ):
     print("loading model")
-    model = AutoModelForCausalLM.from_pretrained(
+    config = AutoConfig.from_pretrained(
         model_id,
-        torch_dtype=torch.float16,
-        device_map="balanced_low_0",
         trust_remote_code=trust_remote_code,
     )
+
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
+    model = model.eval()
 
     print("LOADED model")
+    files = weight_files(model_id, revision, extension=".safetensors")
+    process_group, _, _ = initialize_torch_distributed()
+    weights = Weights(
+        files,
+        device=torch.device("cuda:0"),
+        dtype=torch.float16,
+        process_group=process_group,
+        aliases={"embed_tokens.weight": ["lm_head.weight"]},
+    )
+    hooks = []
+    for name, module in model.named_modules():
+
+        def load(module, name):
+            def _load():
+                load_weights_pre_hook(name, weights, recursive=True)(module, None)

+            return _load
+
+        def unload(module, name):
+            def _unload():
+                load_weights_post_hook(name, weights, recursive=True)(
+                    module, None, None
+                )
+
+            return _unload
+
+        module.load = load(module, name)
+        module.unload = unload(module, name)
+        hooks.append(
+            module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
+        )
+        hooks.append(
+            module.register_forward_hook(load_weights_post_hook(name, weights))
+        )
     model.seqlen = 2048
 
     dataset = "wikitext2"
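Aside: the new `setdeepattr`/`getdeepattr` helpers in the hunk above simply walk dotted attribute paths, so a parameter name as reported by `named_parameters()` (for example "self_attn.q_proj.weight") can be read or replaced on nested modules. A toy illustration, with an invented module tree and the two helpers copied from the diff:

import torch
import torch.nn as nn


def setdeepattr(module, full_name, tensor):
    # Copied from the diff: set module.<a>.<b>...<leaf> to `tensor`.
    current = module
    tokens = full_name.split(".")
    for token in tokens[:-1]:
        current = getattr(current, token)
    setattr(current, tokens[-1], tensor)


def getdeepattr(module, full_name):
    # Copied from the diff: resolve a dotted attribute path.
    current = module
    tokens = full_name.split(".")
    for token in tokens:
        current = getattr(current, token)
    return current


block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
w = getdeepattr(block, "2.weight")          # same name named_parameters() reports
print(w.shape)                              # torch.Size([2, 8])
setdeepattr(block, "2.weight", nn.Parameter(torch.zeros_like(w)))
print(block[2].weight.abs().sum().item())   # 0.0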
@@ -806,6 +922,7 @@ def quantize(
         groupsize,
         percdamp=percdamp,
         act_order=act_order,
+        hooks=hooks,
     )
     print(time.time() - tick)
 
@@ -858,7 +975,6 @@ def quantize(
     logger.info("Saved tokenizer")
 
     if upload_to_model_id:
-
         api = HfApi()
 
         api.upload_folder(