Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-03 07:52:06 +00:00)
Functioning quantization script.
parent 5a72715344
commit a0a194c391
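For context, the diff below adds a working `quantize` command to the server CLI: it downloads the model weights and then runs GPTQ quantization into `output_dir`. A minimal usage sketch, assuming the command is exposed from `text_generation_server.cli` as the surrounding hunks suggest; the model id and output directory below are hypothetical:

    # Hypothetical invocation of the new command added in this commit.
    # Roughly equivalent to the CLI call:
    #   text-generation-server quantize <model-id> <output-dir>
    from text_generation_server.cli import quantize  # module path assumed from the diff

    quantize(
        model_id="bigscience/bloom-560m",    # hypothetical model id
        output_dir="/data/bloom-560m-gptq",  # hypothetical output directory
        revision=None,
        logger_level="INFO",
        json_output=False,
    )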
@@ -150,14 +150,16 @@ def download_weights(
    # Convert pytorch weights to safetensors
    utils.convert_files(local_pt_files, local_st_files)


@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
):
    extension: str = ".safetensors",
    extension: str = (".safetensors",)
    # Remove default handler
    logger.remove()
    logger.add(
@@ -169,12 +171,15 @@ def quantize(
        backtrace=True,
        diagnose=False,
    )
    download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output)
    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
    )
    from text_generation_server.utils.gptq.quantize import quantize

    quantize(model_id=model_id, wbits=4, groupsize=128)

    quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir)


if __name__ == "__main__":
@@ -248,7 +248,9 @@ def get_model(
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError("gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
@@ -59,7 +59,7 @@ def load_row(config, prefix: str, weights, bias: bool):

def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
    weight = weights.get_multi_weights_col([prefix], quantize=config.quantize)
    if isinstance(weight, torch.Tensor):
    if isinstance(weight, torch.Tensor):
        # Only on non quantized versions
        weight = (
            weight.view(
@@ -75,7 +75,6 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
    bias = weights.get_sharded(f"{prefix}.bias", dim=0)
    bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)

    linear = get_linear(weight, bias, config.quantize)
    if config.use_parallel_residual:
        return linear
@@ -1,4 +1,4 @@
#https://github.com/fpgaminer/GPTQ-triton
# https://github.com/fpgaminer/GPTQ-triton
"""
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
"""
@@ -12,15 +12,23 @@ import triton


class Autotuner(triton.KernelInterface):

    def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False):
        '''
        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predicate running time with different configs, returns running time
        'top_k': number of configs to bench
        'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
        'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
        '''
    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        prune_configs_by: Dict = None,
        nearest_power_of_two: bool = False,
    ):
        """
        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predicate running time with different configs, returns running time
        'top_k': number of configs to bench
        'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
        'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
        """
        if not configs:
            self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
        else:
@@ -41,9 +49,12 @@ class Autotuner(triton.KernelInterface):
        self.arg_names = arg_names
        # prune configs
        if prune_configs_by:
            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
            if 'early_config_prune' in prune_configs_by:
                early_config_prune = prune_configs_by['early_config_prune']
            perf_model, top_k = (
                prune_configs_by["perf_model"],
                prune_configs_by["top_k"],
            )
            if "early_config_prune" in prune_configs_by:
                early_config_prune = prune_configs_by["early_config_prune"]
        else:
            perf_model, top_k, early_config_prune = None, None, None
        self.perf_model, self.configs_top_k = perf_model, top_k
@@ -55,8 +66,10 @@ class Autotuner(triton.KernelInterface):
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.kwargs.keys()
        if conflicts:
            raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
                             " Make sure that you don't re-define auto-tuned symbols.")
            raise ValueError(
                f"Conflicting meta-parameters: {', '.join(conflicts)}."
                " Make sure that you don't re-define auto-tuned symbols."
            )
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.kwargs)
@@ -64,14 +77,21 @@ class Autotuner(triton.KernelInterface):
            if config.pre_hook:
                config.pre_hook(self.nargs)
            self.hook(args)
            self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
            self.fn.run(
                *args,
                num_warps=config.num_warps,
                num_stages=config.num_stages,
                **current,
            )

        try:
            # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
            # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
            return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
            return triton.testing.do_bench(
                kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40
            )
        except triton.compiler.OutOfResources:
            return (float('inf'), float('inf'), float('inf'))
            return (float("inf"), float("inf"), float("inf"))

    def run(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
@@ -81,13 +101,16 @@ class Autotuner(triton.KernelInterface):
        # This reduces the amount of autotuning by rounding the keys to the nearest power of two
        # In my testing this gives decent results, and greatly reduces the amount of tuning required
        if self.nearest_power_of_two:
            key = tuple([2**int(math.log2(x) + 0.5) for x in key])
            key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])

        if key not in self.cache:
            # prune configs
            pruned_configs = self.prune_configs(kwargs)
            bench_start = time.time()
            timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
            timings = {
                config: self._bench(*args, config=config, **kwargs)
                for config in pruned_configs
            }
            bench_end = time.time()
            self.bench_time = bench_end - bench_start
            self.cache[key] = builtins.min(timings, key=timings.get)
@@ -99,7 +122,13 @@ class Autotuner(triton.KernelInterface):
        self.best_config = config
        if config.pre_hook is not None:
            config.pre_hook(self.nargs)
        return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
        return self.fn.run(
            *args,
            num_warps=config.num_warps,
            num_stages=config.num_stages,
            **kwargs,
            **config.kwargs,
        )

    def prune_configs(self, kwargs):
        pruned_configs = self.configs
@@ -110,8 +139,19 @@ class Autotuner(triton.KernelInterface):
            if isinstance(top_k, float) and top_k <= 1.0:
                top_k = int(len(self.configs) * top_k)
            if len(pruned_configs) > top_k:
                est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
                est_timing = {
                    config: self.perf_model(
                        **self.nargs,
                        **kwargs,
                        **config.kwargs,
                        num_stages=config.num_stages,
                        num_warps=config.num_warps,
                    )
                    for config in pruned_configs
                }
                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[
                    :top_k
                ]
        return pruned_configs

    def warmup(self, *args, **kwargs):
@ -127,39 +167,49 @@ class Autotuner(triton.KernelInterface):
|
||||
self.nargs = None
|
||||
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
|
||||
def autotune(
|
||||
configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False
|
||||
):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
@triton.autotune(configs=[
|
||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||
],
|
||||
key=['x_size'] # the two above configs will be evaluated anytime
|
||||
# the value of x_size changes
|
||||
)
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||
:note: When all the configurations are evaluated, the kernel will run multiple time.
|
||||
This means that whatever value the kernel updates will be updated multiple times.
|
||||
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||
reset the value of the provided tensor to `zero` before running any configuration.
|
||||
:param configs: a list of :code:`triton.Config` objects
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
:type key: list[str]
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
@triton.autotune(configs=[
|
||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||
],
|
||||
key=['x_size'] # the two above configs will be evaluated anytime
|
||||
# the value of x_size changes
|
||||
)
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||
:note: When all the configurations are evaluated, the kernel will run multiple time.
|
||||
This means that whatever value the kernel updates will be updated multiple times.
|
||||
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||
reset the value of the provided tensor to `zero` before running any configuration.
|
||||
:param configs: a list of :code:`triton.Config` objects
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
:type key: list[str]
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two)
|
||||
return Autotuner(
|
||||
fn,
|
||||
fn.arg_names,
|
||||
configs,
|
||||
key,
|
||||
reset_to_zero,
|
||||
prune_configs_by,
|
||||
nearest_power_of_two,
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
||||
@ -168,26 +218,44 @@ def matmul248_kernel_config_pruner(configs, nargs):
|
||||
"""
|
||||
The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
|
||||
"""
|
||||
m = max(2**int(math.ceil(math.log2(nargs['M']))), 16)
|
||||
n = max(2**int(math.ceil(math.log2(nargs['N']))), 16)
|
||||
k = max(2**int(math.ceil(math.log2(nargs['K']))), 16)
|
||||
m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
|
||||
n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
|
||||
k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)
|
||||
|
||||
used = set()
|
||||
for config in configs:
|
||||
block_size_m = min(m, config.kwargs['BLOCK_SIZE_M'])
|
||||
block_size_n = min(n, config.kwargs['BLOCK_SIZE_N'])
|
||||
block_size_k = min(k, config.kwargs['BLOCK_SIZE_K'])
|
||||
group_size_m = config.kwargs['GROUP_SIZE_M']
|
||||
block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
|
||||
block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
|
||||
block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
|
||||
group_size_m = config.kwargs["GROUP_SIZE_M"]
|
||||
|
||||
if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used:
|
||||
if (
|
||||
block_size_m,
|
||||
block_size_n,
|
||||
block_size_k,
|
||||
group_size_m,
|
||||
config.num_stages,
|
||||
config.num_warps,
|
||||
) in used:
|
||||
continue
|
||||
|
||||
used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps))
|
||||
yield triton.Config({
|
||||
'BLOCK_SIZE_M': block_size_m,
|
||||
'BLOCK_SIZE_N': block_size_n,
|
||||
'BLOCK_SIZE_K': block_size_k,
|
||||
'GROUP_SIZE_M': group_size_m
|
||||
},
|
||||
num_stages=config.num_stages,
|
||||
num_warps=config.num_warps)
|
||||
used.add(
|
||||
(
|
||||
block_size_m,
|
||||
block_size_n,
|
||||
block_size_k,
|
||||
group_size_m,
|
||||
config.num_stages,
|
||||
config.num_warps,
|
||||
)
|
||||
)
|
||||
yield triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_size_m,
|
||||
"BLOCK_SIZE_N": block_size_n,
|
||||
"BLOCK_SIZE_K": block_size_k,
|
||||
"GROUP_SIZE_M": group_size_m,
|
||||
},
|
||||
num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
)
|
||||
|
@ -12,66 +12,121 @@ try:
|
||||
# code based https://github.com/fpgaminer/GPTQ-triton
|
||||
@custom_autotune.autotune(
|
||||
configs=[
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=3, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
key=["M", "N", "K"],
|
||||
nearest_power_of_two=True,
|
||||
prune_configs_by={
|
||||
'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
|
||||
'perf_model': None,
|
||||
'top_k': None,
|
||||
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
|
||||
"perf_model": None,
|
||||
"top_k": None,
|
||||
},
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
|
||||
def matmul_248_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
scales_ptr,
|
||||
zeros_ptr,
|
||||
g_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
maxq,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_scales,
|
||||
stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, K) float16
|
||||
@ -79,7 +134,7 @@ try:
|
||||
C is of shape (M, N) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
g_ptr is of shape (K) int32
|
||||
g_ptr is of shape (K) int32
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
@ -97,10 +152,15 @@ try:
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a_mask = (offs_am[:, None] < M)
|
||||
a_ptrs = a_ptr + (
|
||||
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
|
||||
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a_mask = offs_am[:, None] < M
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
b_ptrs = b_ptr + (
|
||||
(offs_k[:, None] // infearure_per_bits) * stride_bk
|
||||
+ offs_bn[None, :] * stride_bn
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
g_ptrs = g_ptr + offs_k
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_bn[None, :]
|
||||
@ -114,13 +174,17 @@ try:
|
||||
g_idx = tl.load(g_ptrs)
|
||||
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
scales = tl.load(
|
||||
scales_ptrs + g_idx[:, None] * stride_scales
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(
|
||||
zeros_ptrs + g_idx[:, None] * stride_zeros
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
zeros = zeros + 1
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
@ -136,61 +200,118 @@ try:
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
@custom_autotune.autotune(configs=[
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 256,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=3, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
nearest_power_of_two=True)
|
||||
@custom_autotune.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
],
|
||||
key=["M", "N", "K"],
|
||||
nearest_power_of_two=True,
|
||||
)
|
||||
@triton.jit
|
||||
def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
|
||||
stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
|
||||
def transpose_matmul_248_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
scales_ptr,
|
||||
zeros_ptr,
|
||||
g_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
maxq,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_scales,
|
||||
stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, N) float16
|
||||
@ -198,7 +319,7 @@ try:
|
||||
C is of shape (M, K) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
g_ptr is of shape (K) int32
|
||||
g_ptr is of shape (K) int32
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
@ -216,16 +337,25 @@ try:
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
a_mask = (offs_am[:, None] < M)
|
||||
a_ptrs = a_ptr + (
|
||||
offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak
|
||||
) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
a_mask = offs_am[:, None] < M
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
b_ptrs = b_ptr + (
|
||||
(offs_bk[:, None] // infearure_per_bits) * stride_bk
|
||||
+ offs_n[None, :] * stride_bn
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
g_ptrs = g_ptr + offs_bk
|
||||
g_idx = tl.load(g_ptrs)
|
||||
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
|
||||
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
|
||||
zeros_ptrs = (
|
||||
zeros_ptr
|
||||
+ (offs_n[None, :] // infearure_per_bits)
|
||||
+ g_idx[:, None] * stride_zeros
|
||||
)
|
||||
|
||||
shifter = (offs_bk % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_n % infearure_per_bits) * bits
|
||||
@ -237,9 +367,9 @@ try:
|
||||
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
zeros = zeros + 1
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
@ -251,36 +381,84 @@ try:
|
||||
a_ptrs += BLOCK_SIZE_N
|
||||
b_ptrs += BLOCK_SIZE_N
|
||||
scales_ptrs += BLOCK_SIZE_N
|
||||
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
|
||||
zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits
|
||||
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
except:
|
||||
print('triton not installed.')
|
||||
print("triton not installed.")
|
||||
|
||||
|
||||
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
|
||||
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )
|
||||
matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
|
||||
qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
|
||||
output = torch.empty(
|
||||
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
|
||||
)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
matmul_248_kernel[grid](
|
||||
input,
|
||||
qweight,
|
||||
output,
|
||||
scales,
|
||||
qzeros,
|
||||
g_idx,
|
||||
input.shape[0],
|
||||
qweight.shape[1],
|
||||
input.shape[1],
|
||||
bits,
|
||||
maxq,
|
||||
input.stride(0),
|
||||
input.stride(1),
|
||||
qweight.stride(0),
|
||||
qweight.stride(1),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
scales.stride(0),
|
||||
qzeros.stride(0),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output_dim = (qweight.shape[0] * 32) // bits
|
||||
output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)
|
||||
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )
|
||||
transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
|
||||
qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
|
||||
output = torch.empty(
|
||||
(input.shape[0], output_dim), device=input.device, dtype=torch.float16
|
||||
)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(output_dim, META["BLOCK_SIZE_K"]),
|
||||
)
|
||||
transpose_matmul_248_kernel[grid](
|
||||
input,
|
||||
qweight,
|
||||
output,
|
||||
scales,
|
||||
qzeros,
|
||||
g_idx,
|
||||
input.shape[0],
|
||||
qweight.shape[1],
|
||||
output_dim,
|
||||
bits,
|
||||
maxq,
|
||||
input.stride(0),
|
||||
input.stride(1),
|
||||
qweight.stride(0),
|
||||
qweight.stride(1),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
scales.stride(0),
|
||||
qzeros.stride(0),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class QuantLinearFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
@ -297,7 +475,9 @@ class QuantLinearFunction(torch.autograd.Function):
|
||||
grad_input = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
grad_input = transpose_matmul248(
|
||||
grad_output, qweight, scales, qzeros, g_idx, bits, maxq
|
||||
)
|
||||
return grad_input, None, None, None, None, None, None
|
||||
|
||||
|
||||
@ -318,8 +498,41 @@ class QuantLinear(nn.Module):
|
||||
self.outfeatures = qweight.shape[1]
|
||||
self.infeatures = qweight.shape[0] * 32 // 4
|
||||
|
||||
@classmethod
|
||||
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
||||
super().__init__()
|
||||
if bits not in [2, 4, 8]:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
|
||||
qweight = torch.zeros(
|
||||
(infeatures // 32 * self.bits, outfeatures), dtype=torch.int32
|
||||
)
|
||||
qzeros = torch.zeros(
|
||||
(math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
scales = torch.zeros(
|
||||
(math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16
|
||||
)
|
||||
g_idx = torch.tensor(
|
||||
[i // self.groupsize for i in range(infeatures)], dtype=torch.int32
|
||||
)
|
||||
if bias:
|
||||
bias = torch.zeros((outfeatures), dtype=torch.float16)
|
||||
else:
|
||||
bias = None
|
||||
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures, )
|
||||
out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
out = QuantLinearFunction.apply(
|
||||
x.reshape(-1, x.shape[-1]),
|
||||
self.qweight,
|
||||
self.scales,
|
||||
self.qzeros,
|
||||
self.g_idx,
|
||||
self.bits,
|
||||
self.maxq,
|
||||
)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.reshape(out_shape)
|
||||
|
@ -4,26 +4,36 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
import os
|
||||
|
||||
from texttable import Texttable
|
||||
from transformers import AutoModelForCausalLM
|
||||
import transformers
|
||||
import numpy as np
|
||||
import torch
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
|
||||
DEV = torch.device("cuda:0")
|
||||
|
||||
|
||||
class Quantizer(nn.Module):
|
||||
|
||||
def __init__(self, shape=1):
|
||||
super(Quantizer, self).__init__()
|
||||
self.register_buffer('maxq', torch.tensor(0))
|
||||
self.register_buffer('scale', torch.zeros(shape))
|
||||
self.register_buffer('zero', torch.zeros(shape))
|
||||
|
||||
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):
|
||||
self.register_buffer("maxq", torch.tensor(0))
|
||||
self.register_buffer("scale", torch.zeros(shape))
|
||||
self.register_buffer("zero", torch.zeros(shape))
|
||||
|
||||
def configure(
|
||||
self,
|
||||
bits,
|
||||
perchannel=False,
|
||||
sym=True,
|
||||
mse=False,
|
||||
norm=2.4,
|
||||
grid=100,
|
||||
maxshrink=0.8,
|
||||
trits=False,
|
||||
):
|
||||
self.maxq = torch.tensor(2**bits - 1)
|
||||
self.perchannel = perchannel
|
||||
self.sym = sym
|
||||
@ -84,14 +94,16 @@ class Quantizer(nn.Module):
|
||||
self.zero = torch.round(-xmin / self.scale)
|
||||
|
||||
if self.mse:
|
||||
best = torch.full([x.shape[0]], float('inf'), device=dev)
|
||||
best = torch.full([x.shape[0]], float("inf"), device=dev)
|
||||
for i in range(int(self.maxshrink * self.grid)):
|
||||
p = 1 - i / self.grid
|
||||
xmin1 = p * xmin
|
||||
xmax1 = p * xmax
|
||||
scale1 = (xmax1 - xmin1) / self.maxq
|
||||
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
|
||||
q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
|
||||
q = self._quantize(
|
||||
x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
|
||||
)
|
||||
q -= x
|
||||
q.abs_()
|
||||
q.pow_(self.norm)
|
||||
@ -138,7 +150,6 @@ class Quantizer(nn.Module):
|
||||
|
||||
|
||||
class GPTQ:
|
||||
|
||||
def __init__(self, layer, observe=False):
|
||||
self.layer = layer
|
||||
self.dev = self.layer.weight.device
|
||||
@ -166,12 +177,19 @@ class GPTQ:
|
||||
if len(inp.shape) == 2:
|
||||
inp = inp.unsqueeze(0)
|
||||
tmp = inp.shape[0]
|
||||
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
|
||||
if isinstance(self.layer, nn.Linear) or isinstance(
|
||||
self.layer, transformers.Conv1D
|
||||
):
|
||||
if len(inp.shape) == 3:
|
||||
inp = inp.reshape((-1, inp.shape[-1]))
|
||||
inp = inp.t()
|
||||
if isinstance(self.layer, nn.Conv2d):
|
||||
unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride)
|
||||
unfold = nn.Unfold(
|
||||
self.layer.kernel_size,
|
||||
dilation=self.layer.dilation,
|
||||
padding=self.layer.padding,
|
||||
stride=self.layer.stride,
|
||||
)
|
||||
inp = unfold(inp)
|
||||
inp = inp.permute([1, 0, 2])
|
||||
inp = inp.flatten(1)
|
||||
@ -184,12 +202,14 @@ class GPTQ:
|
||||
|
||||
def print_loss(self, name, q_weight, weight_error, timecost):
|
||||
table = Texttable()
|
||||
name += ' ' * (16 - len(name))
|
||||
name += " " * (16 - len(name))
|
||||
|
||||
table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time'])
|
||||
table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])
|
||||
|
||||
# assign weight
|
||||
self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
|
||||
self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
|
||||
self.layer.weight.data.dtype
|
||||
)
|
||||
|
||||
if self.inp1 is not None:
|
||||
# quantize input to int8
|
||||
@ -203,13 +223,15 @@ class GPTQ:
|
||||
q_SNR = torch_snr_error(q_out, self.out1).item()
|
||||
fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
|
||||
else:
|
||||
q_SNR = '-'
|
||||
fp_SNR = '-'
|
||||
q_SNR = "-"
|
||||
fp_SNR = "-"
|
||||
|
||||
table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
|
||||
print(table.draw().split('\n')[-2])
|
||||
print(table.draw().split("\n")[-2])
|
||||
|
||||
def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''):
|
||||
def fasterquant(
|
||||
self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
|
||||
):
|
||||
self.layer.to(self.dev)
|
||||
|
||||
W = self.layer.weight.data.clone()
|
||||
@ -268,7 +290,9 @@ class GPTQ:
|
||||
|
||||
if groupsize != -1:
|
||||
if (i1 + i) % groupsize == 0:
|
||||
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
|
||||
self.quantizer.find_params(
|
||||
W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
|
||||
)
|
||||
|
||||
if ((i1 + i) // groupsize) - now_idx == -1:
|
||||
scale.append(self.quantizer.scale)
|
||||
@ -277,7 +301,7 @@ class GPTQ:
|
||||
|
||||
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
|
||||
Q1[:, i] = q
|
||||
Losses1[:, i] = (w - q)**2 / d**2
|
||||
Losses1[:, i] = (w - q) ** 2 / d**2
|
||||
|
||||
err1 = (w - q) / d
|
||||
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
|
||||
@ -302,7 +326,9 @@ class GPTQ:
|
||||
if isinstance(self.layer, transformers.Conv1D):
|
||||
Q = Q.t()
|
||||
|
||||
self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick))
|
||||
self.print_loss(
|
||||
name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
|
||||
)
|
||||
|
||||
if scale == []:
|
||||
scale.append(self.quantizer.scale)
|
||||
@ -322,15 +348,18 @@ class GPTQ:
|
||||
|
||||
def get_wikitext2(nsamples, seed, seqlen, model_id):
|
||||
from datasets import load_dataset
|
||||
traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
|
||||
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
|
||||
|
||||
traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
|
||||
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
||||
trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
|
||||
testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
|
||||
trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
|
||||
testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
|
||||
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
trainloader = []
|
||||
for _ in range(nsamples):
|
||||
@ -345,18 +374,21 @@ def get_wikitext2(nsamples, seed, seqlen, model_id):
|
||||
|
||||
def get_ptb(nsamples, seed, seqlen, model_id):
|
||||
from datasets import load_dataset
|
||||
traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
|
||||
valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
|
||||
|
||||
traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
|
||||
valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
||||
except:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
|
||||
trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt')
|
||||
testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt')
|
||||
trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
|
||||
testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
|
||||
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
trainloader = []
|
||||
for _ in range(nsamples):
|
||||
@ -371,22 +403,37 @@ def get_ptb(nsamples, seed, seqlen, model_id):
|
||||
|
||||
def get_c4(nsamples, seed, seqlen, model_id):
|
||||
from datasets import load_dataset
|
||||
traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False)
|
||||
valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False)
|
||||
|
||||
traindata = load_dataset(
|
||||
"allenai/c4",
|
||||
"allenai--c4",
|
||||
data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
|
||||
split="train",
|
||||
use_auth_token=False,
|
||||
)
|
||||
valdata = load_dataset(
|
||||
"allenai/c4",
|
||||
"allenai--c4",
|
||||
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
|
||||
split="validation",
|
||||
use_auth_token=False,
|
||||
)
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
||||
except:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
|
||||
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
trainloader = []
|
||||
for _ in range(nsamples):
|
||||
while True:
|
||||
i = random.randint(0, len(traindata) - 1)
|
||||
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
|
||||
trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
|
||||
if trainenc.input_ids.shape[1] >= seqlen:
|
||||
break
|
||||
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
|
||||
@ -397,12 +444,13 @@ def get_c4(nsamples, seed, seqlen, model_id):
|
||||
trainloader.append((inp, tar))
|
||||
|
||||
import random
|
||||
|
||||
random.seed(0)
|
||||
valenc = []
|
||||
for _ in range(256):
|
||||
while True:
|
||||
i = random.randint(0, len(valdata) - 1)
|
||||
tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
|
||||
tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
|
||||
if tmp.input_ids.shape[1] >= seqlen:
|
||||
break
|
||||
i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
|
||||
@ -411,7 +459,6 @@ def get_c4(nsamples, seed, seqlen, model_id):
|
||||
valenc = torch.hstack(valenc)
|
||||
|
||||
class TokenizerWrapper:
|
||||
|
||||
def __init__(self, input_ids):
|
||||
self.input_ids = input_ids
|
||||
|
||||
@ -422,18 +469,21 @@ def get_c4(nsamples, seed, seqlen, model_id):
|
||||
|
||||
def get_ptb_new(nsamples, seed, seqlen, model_id):
|
||||
from datasets import load_dataset
|
||||
traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
|
||||
testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
|
||||
|
||||
traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
|
||||
testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
||||
except:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
|
||||
trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
|
||||
testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
|
||||
trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
|
||||
testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")
|
||||
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
trainloader = []
|
||||
for _ in range(nsamples):
|
||||
@ -448,22 +498,35 @@ def get_ptb_new(nsamples, seed, seqlen, model_id):
|
||||
|
||||
def get_c4_new(nsamples, seed, seqlen, model_id):
|
||||
from datasets import load_dataset
|
||||
traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
|
||||
valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
|
||||
|
||||
traindata = load_dataset(
|
||||
"allenai/c4",
|
||||
"allenai--c4",
|
||||
data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
|
||||
split="train",
|
||||
)
|
||||
valdata = load_dataset(
|
||||
"allenai/c4",
|
||||
"allenai--c4",
|
||||
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
|
||||
split="validation",
|
||||
)
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
||||
except:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
|
||||
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
trainloader = []
|
||||
for _ in range(nsamples):
|
||||
while True:
|
||||
i = random.randint(0, len(traindata) - 1)
|
||||
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
|
||||
trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
|
||||
if trainenc.input_ids.shape[1] >= seqlen:
|
||||
break
|
||||
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
|
||||
@ -473,11 +536,10 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
|
||||
tar[:, :-1] = -100
|
||||
trainloader.append((inp, tar))
|
||||
|
||||
valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
|
||||
valenc = valenc.input_ids[:, :(256 * seqlen)]
|
||||
valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
|
||||
valenc = valenc.input_ids[:, : (256 * seqlen)]
|
||||
|
||||
class TokenizerWrapper:
|
||||
|
||||
def __init__(self, input_ids):
|
||||
self.input_ids = input_ids
|
||||
|
||||
@ -486,31 +548,46 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
|
||||
return trainloader, valenc
|
||||
|
||||
|
||||
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=''):
|
||||
if 'wikitext2' in name:
|
||||
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""):
|
||||
if "wikitext2" in name:
|
||||
return get_wikitext2(nsamples, seed, seqlen, model_id)
|
||||
if 'ptb' in name:
|
||||
if 'new' in name:
|
||||
if "ptb" in name:
|
||||
if "new" in name:
|
||||
return get_ptb_new(nsamples, seed, seqlen, model_id)
|
||||
return get_ptb(nsamples, seed, seqlen, model_id)
|
||||
if 'c4' in name:
|
||||
if 'new' in name:
|
||||
if "c4" in name:
|
||||
if "new" in name:
|
||||
return get_c4_new(nsamples, seed, seqlen, model_id)
|
||||
return get_c4(nsamples, seed, seqlen, model_id)
|
||||
|
||||
|
||||
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
|
||||
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
|
||||
# Skip last lm_head linear
|
||||
if type(module) in layers and "lm_head" not in name:
|
||||
return {name: module}
|
||||
res = {}
|
||||
for name1, child in module.named_children():
|
||||
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
|
||||
res.update(
|
||||
find_layers(
|
||||
child, layers=layers, name=name + "." + name1 if name != "" else name1
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False):
|
||||
print('Starting ...')
|
||||
def sequential(
|
||||
model,
|
||||
dataloader,
|
||||
dev,
|
||||
nsamples,
|
||||
bits,
|
||||
groupsize,
|
||||
percdamp=0.01,
|
||||
sym: bool = False,
|
||||
act_order: bool = False,
|
||||
):
|
||||
print("Starting ...")
|
||||
|
||||
use_cache = model.config.use_cache
|
||||
model.config.use_cache = False
|
||||
@ -524,20 +601,21 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
|
||||
# layers[0] = layers[0].to(dev)
|
||||
|
||||
dtype = next(iter(model.parameters())).dtype
|
||||
inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
|
||||
cache = {'i': 0, 'attention_mask': None}
|
||||
inps = torch.zeros(
|
||||
(nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
|
||||
)
|
||||
cache = {"i": 0, "attention_mask": None}
|
||||
|
||||
class Catcher(nn.Module):
|
||||
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self, inp, **kwargs):
|
||||
inps[cache['i']] = inp
|
||||
cache['i'] += 1
|
||||
cache['attention_mask'] = kwargs['attention_mask']
|
||||
cache['position_ids'] = kwargs['position_ids']
|
||||
inps[cache["i"]] = inp
|
||||
cache["i"] += 1
|
||||
cache["attention_mask"] = kwargs["attention_mask"]
|
||||
cache["position_ids"] = kwargs["position_ids"]
|
||||
raise ValueError
|
||||
|
||||
layers[0] = Catcher(layers[0])
|
||||
@ -554,20 +632,20 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
outs = torch.zeros_like(inps)
|
||||
attention_mask = cache['attention_mask'].to(dev)
|
||||
position_ids = cache['position_ids'].to(dev)
|
||||
attention_mask = cache["attention_mask"].to(dev)
|
||||
position_ids = cache["position_ids"].to(dev)
|
||||
|
||||
print('Ready.')
|
||||
print("Ready.")
|
||||
|
||||
quantizers = {}
|
||||
for i in range(len(layers)):
|
||||
|
||||
print(f'Quantizing layer {i+1}/{len(layers)}..')
|
||||
print('+------------------+--------------+------------+-----------+-------+')
|
||||
print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |')
|
||||
print('+==================+==============+============+===========+=======+')
|
||||
print(f"Quantizing layer {i+1}/{len(layers)}..")
|
||||
print("+------------------+--------------+------------+-----------+-------+")
|
||||
print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
|
||||
print("+==================+==============+============+===========+=======+")
|
||||
|
||||
from accelerate.hooks import remove_hook_from_submodules
|
||||
|
||||
layer = layers[i].to(dev)
|
||||
remove_hook_from_submodules(layer)
|
||||
full = find_layers(layer)
|
||||
@ -578,10 +656,11 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
|
||||
gptq = {}
|
||||
for name in subset:
|
||||
gptq[name] = GPTQ(subset[name])
|
||||
gptq[name].quantizer.configure(wbits, perchannel=True, sym=sym, mse=False)
|
||||
gptq[name].quantizer.configure(
|
||||
bits, perchannel=True, sym=sym, mse=False
|
||||
)
|
||||
|
||||
def add_batch(name):
|
||||
|
||||
def tmp(_, inp, out):
|
||||
gptq[name].add_batch(inp[0].data, out.data)
|
||||
|
||||
@ -591,19 +670,38 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
for name in subset:
handles.append(subset[name].register_forward_hook(add_batch(name)))
for j in range(nsamples):

outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
outs[j] = layer(
inps[j].unsqueeze(0),
attention_mask=attention_mask,
position_ids=position_ids,
)[0]
for h in handles:
h.remove()

for name in subset:
scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=percdamp, groupsize=groupsize, actorder=act_order, name=name)
quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize)
scale, zero, g_idx, error = gptq[name].fasterquant(
percdamp=percdamp,
groupsize=groupsize,
actorder=act_order,
name=name,
)
quantizers["model.layers.%d.%s" % (i, name)] = (
gptq[name].quantizer.cpu(),
scale.cpu(),
zero.cpu(),
g_idx.cpu(),
bits,
groupsize,
)

gptq[name].free()

for j in range(nsamples):
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
outs[j] = layer(
inps[j].unsqueeze(0),
attention_mask=attention_mask,
position_ids=position_ids,
)[0]

layers[i] = layer.cpu()
del layer
@ -611,12 +709,12 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
torch.cuda.empty_cache()

inps, outs = outs, inps
print('+------------------+--------------+------------+-----------+-------+')
print('\n')
print("+------------------+--------------+------------+-----------+-------+")
print("\n")

# if args.observe:
# observer.print()
# conditions = gen_conditions(args.wbits, args.groupsize)
# conditions = gen_conditions(args.bits, args.groupsize)
# for item in observer.items():
# name = item[0]
# layerid = item[1]
@ -625,23 +723,23 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
# target = error / 2

# table = Texttable()
# table.header(['wbits', 'groupsize', 'error'])
# table.header(['bits', 'groupsize', 'error'])
# table.set_cols_dtype(['i', 'i', 'f'])
# table.add_row([args.wbits, args.groupsize, error])
# table.add_row([args.bits, args.groupsize, error])

# print('Optimizing {} {} ..'.format(name, layerid))
# for wbits, groupsize in conditions:
# for bits, groupsize in conditions:

# if error < target:
# # if error dropped 50%, skip
# break

# gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False)
# gptq.quantizer.configure(bits, perchannel=True, sym=args.sym, mse=False)

# scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name)

# table.add_row([wbits, groupsize, error])
# quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize)
# table.add_row([bits, groupsize, error])
# quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), bits, groupsize)

# print(table.draw())
# print('\n')
@ -656,34 +754,34 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
# @torch.no_grad()
# def llama_eval(model, testenc, dev):
# print('Evaluating ...')
#
#
# testenc = testenc.input_ids
# nsamples = testenc.numel() // model.seqlen
#
#
# use_cache = model.config.use_cache
# model.config.use_cache = False
# layers = model.model.layers
#
#
# model.model.embed_tokens = model.model.embed_tokens.to(dev)
# layers[0] = layers[0].to(dev)
#
#
# dtype = next(iter(model.parameters())).dtype
# inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
# cache = {'i': 0, 'attention_mask': None}
#
#
# class Catcher(nn.Module):
#
#
# def __init__(self, module):
# super().__init__()
# self.module = module
#
#
# def forward(self, inp, **kwargs):
# inps[cache['i']] = inp
# cache['i'] += 1
# cache['attention_mask'] = kwargs['attention_mask']
# cache['position_ids'] = kwargs['position_ids']
# raise ValueError
#
#
# layers[0] = Catcher(layers[0])
# for i in range(nsamples):
# batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
@ -692,39 +790,39 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
# except ValueError:
# pass
# layers[0] = layers[0].module
#
#
# layers[0] = layers[0].cpu()
# model.model.embed_tokens = model.model.embed_tokens.cpu()
# torch.cuda.empty_cache()
#
#
# outs = torch.zeros_like(inps)
# attention_mask = cache['attention_mask']
# position_ids = cache['position_ids']
#
#
# for i in range(len(layers)):
# print(i)
# layer = layers[i].to(dev)
#
#
# if args.nearest:
# subset = find_layers(layer)
# for name in subset:
# quantizer = quant.Quantizer()
# quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
# quantizer.configure(args.bits, perchannel=True, sym=args.sym, mse=False)
# W = subset[name].weight.data
# quantizer.find_params(W, weight=True)
# subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
#
#
# for j in range(nsamples):
# outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
# layers[i] = layer.cpu()
# del layer
# torch.cuda.empty_cache()
# inps, outs = outs, inps
#
#
# if model.model.norm is not None:
# model.model.norm = model.model.norm.to(dev)
# model.lm_head = model.lm_head.to(dev)
#
#
# testenc = testenc.to(dev)
# nlls = []
# for i in range(nsamples):
@ -740,36 +838,61 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
# nlls.append(neg_log_likelihood)
# ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
# print(ppl.item())
#
#
# model.config.use_cache = use_cache


def make_quant_linear(module, names, bits, groupsize, name=""):
if isinstance(module, QuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
name1 = name + "." + attr if name != "" else attr
if name1 in names:
delattr(module, attr)
setattr(
module,
attr,
QuantLinear.new(
bits,
groupsize,
tmp.in_features,
tmp.out_features,
tmp.bias is not None,
),
)
for name1, child in module.named_children():
make_quant_linear(
child, names, bits, groupsize, name + "." + name1 if name != "" else name1
)


# TODO: perform packing on GPU
def pack(model, quantizers, wbits, groupsize):
def pack(model, quantizers, bits, groupsize):
layers = find_layers(model)
layers = {n: layers[n] for n in quantizers}
quant.make_quant_linear(model, quantizers, wbits, groupsize)
make_quant_linear(model, quantizers, bits, groupsize)
qlayers = find_layers(model, [QuantLinear])
print('Packing ...')
print("Packing ...")
for name in qlayers:
print(name)
quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
qlayers[name].pack(layers[name], scale, zero, g_idx)
print('Done.')
print("Done.")
return model
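For orientation, a hedged sketch (not part of this commit) of what pack() leaves in the model; the layer name below assumes a LLaMA-style module layout, while the buffer names match the ones the server-side loader reads back in weights.py:

packed = pack(model, quantizers, bits=4, groupsize=128)
sd = packed.state_dict()
# Each quantized projection now exposes packed GPTQ buffers instead of a dense weight, e.g.:
# sd["model.layers.0.self_attn.q_proj.qweight"]  packed integer weights
# sd["model.layers.0.self_attn.q_proj.qzeros"]   packed zero points
# sd["model.layers.0.self_attn.q_proj.scales"]   per-group scales
# sd["model.layers.0.self_attn.q_proj.g_idx"]    group index for each input column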


# def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
# def load_quant(model, checkpoint, bits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
# from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils
# config = LlamaConfig.from_pretrained(model)
#
#
# def noop(*args, **kwargs):
# pass
#
#
# torch.nn.init.kaiming_uniform_ = noop
# torch.nn.init.uniform_ = noop
# torch.nn.init.normal_ = noop
#
#
# torch.set_default_dtype(torch.half)
# modeling_utils._init_weights = False
# torch.set_default_dtype(torch.half)
@ -781,30 +904,30 @@ def pack(model, quantizers, wbits, groupsize):
# for name in ['lm_head']:
# if name in layers:
# del layers[name]
# quant.make_quant_linear(model, layers, wbits, groupsize)
#
# quant.make_quant_linear(model, layers, bits, groupsize)
#
# del layers
#
#
# print('Loading model ...')
# if checkpoint.endswith('.safetensors'):
# from safetensors.torch import load_file as safe_load
# model.load_state_dict(safe_load(checkpoint))
# else:
# model.load_state_dict(torch.load(checkpoint))
#
#
# if eval:
# quant.make_quant_attn(model)
# quant.make_quant_norm(model)
# if fused_mlp:
# quant.make_fused_mlp(model)
#
#
# if warmup_autotune:
# quant.autotune_warmup_linear(model, transpose=not (eval))
# if eval and fused_mlp:
# quant.autotune_warmup_fused(model)
# model.seqlen = 2048
# print('Done.')
#
#
# return model


@ -814,33 +937,33 @@ def pack(model, quantizers, wbits, groupsize):
# model.model.norm = model.model.norm.to(gpus[0])
# import copy
# model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0])
#
#
# cache = {'mask': None, 'position_ids': None}
#
#
# class MoveModule(nn.Module):
#
#
# def __init__(self, module, invalidate_cache):
# super().__init__()
# self.module = module
# self.dev = next(iter(self.module.parameters())).device
# self.invalidate_cache=invalidate_cache
#
#
# def forward(self, *inp, **kwargs):
# inp = list(inp)
# if inp[0].device != self.dev:
# inp[0] = inp[0].to(self.dev)
#
#
# if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache:
# cache['mask'] = kwargs['attention_mask'].to(self.dev)
# kwargs['attention_mask'] = cache['mask']
#
#
# if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache:
# cache['position_ids'] = kwargs['position_ids'].to(self.dev)
# kwargs['position_ids'] = cache['position_ids']
#
#
# tmp = self.module(*inp, **kwargs)
# return tmp
#
#
# layers = model.model.layers
# from math import ceil
# if not gpu_dist:
@ -852,49 +975,49 @@ def pack(model, quantizers, wbits, groupsize):
# assigned_gpus = [0] * (gpu_dist[0]-1)
# for i in range(1, len(gpu_dist)):
# assigned_gpus = assigned_gpus + [i] * gpu_dist[i]
#
#
# remaining_assignments = len(layers)-len(assigned_gpus) - 1
# if remaining_assignments > 0:
# assigned_gpus = assigned_gpus + [-1] * remaining_assignments
#
#
# assigned_gpus = assigned_gpus + [0]
#
#
# for i in range(len(layers)):
# layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0)
#
#
# model.gpus = gpus
#
#
#
#
# def benchmark(model, input_ids, check=False):
# input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
# torch.cuda.synchronize()
#
#
# cache = {'past': None}
#
#
# def clear_past(i):
#
#
# def tmp(layer, inp, out):
# if cache['past']:
# cache['past'][i] = None
#
#
# return tmp
#
#
# for i, layer in enumerate(model.model.layers):
# layer.register_forward_hook(clear_past(i))
#
#
# print('Benchmarking ...')
#
#
# if check:
# loss = nn.CrossEntropyLoss()
# tot = 0.
#
#
# def sync():
# if hasattr(model, 'gpus'):
# for gpu in model.gpus:
# torch.cuda.synchronize(gpu)
# else:
# torch.cuda.synchronize()
#
#
# max_memory = 0
# with torch.no_grad():
# attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
@ -921,9 +1044,11 @@ def pack(model, quantizers, wbits, groupsize):
# print('max memory(MiB):', max_memory)


def quantize(model_id: str, wbits: int, groupsize: int):
def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
print("loading model")
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0")
model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, device_map="balanced_low_0"
)
print("LOADED model")
model.seqlen = 2048

@ -931,11 +1056,12 @@ def quantize(model_id: str, wbits: int, groupsize: int):
nsamples = 128
seed = None


dataloader, testloader = get_loaders(dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen)
dataloader, testloader = get_loaders(
dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen
)

tick = time.time()
quantizers = sequential(model, dataloader, DEV, nsamples, wbits, groupsize)
quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
print(time.time() - tick)

# if args.benchmark:
@ -956,7 +1082,7 @@ def quantize(model_id: str, wbits: int, groupsize: int):
# dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
# print(dataset)
# llama_eval(model, testloader, DEV)
#
#
# if args.test_generation:
# gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
# if len(gpus) > 1:
@ -970,20 +1096,57 @@ def quantize(model_id: str, wbits: int, groupsize: int):
# streamer = TextStreamer(tokenizer)
# with torch.no_grad():
# generated_ids = model.generate(input_ids, streamer=streamer)
#

#

# if args.quant_directory is not None:
# export_quant_table(quantizers, args.quant_directory)

# if not args.observe and args.save:
# llama_pack(model, quantizers, args.wbits, args.groupsize)
# llama_pack(model, quantizers, args.bits, args.groupsize)
# torch.save(model.state_dict(), args.save)

# if not args.observe and args.save_safetensors:
pack(model, quantizers, wbits, groupsize)
from safetensors.torch import save_file as safe_save
state_dict = model.state_dict()
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
safe_save(state_dict, args.save_safetensors)
pack(model, quantizers, bits, groupsize)
from safetensors.torch import save_file
from transformers.modeling_utils import shard_checkpoint

state_dict = model.state_dict()
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
state_dict["gptq_bits"] = torch.LongTensor([bits])
state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])

shards, index = shard_checkpoint(
state_dict, max_shard_size="10GB", weights_name="model.safetensors"
)
os.makedirs(output_dir, exist_ok=True)
for shard_file, shard in shards.items():
save_file(
shard,
os.path.join(output_dir, shard_file),
metadata={
"format": "pt",
"quantized": "gptq",
"origin": "text-generation-inference",
},
)
if index is None:
path_to_weights = os.path.join(output_dir, "model.safetensors")
logger.info(f"Model weights saved in {path_to_weights}")
else:
save_index_file = "model.safetensors.index.json"
save_index_file = os.path.join(output_dir, save_index_file)
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint (10GB) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
f"index located at {save_index_file}."
)
config = AutoConfig.from_pretrained(model_id)
config.save_pretrained(output_dir)
logger.info("Saved config")
logger.info("Saving tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(output_dir)
logger.info("Saved tokenizer")

@ -134,13 +134,15 @@ def get_linear(weight, bias, quantize):
try:
qweight, qzeros, scales, g_idx, bits, groupsize = weight
except Exception:
raise NotImplementedError(f"The passed weight is not `gptq` compatible, loader needs to be updated.")
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)

linear = QuantLinear(
qweight,
qzeros,
scales,
g_idx,
bias,
bits,
groupsize,

@ -86,19 +86,27 @@ class Weights:
def get_multi_weights_col(self, prefixes: List[str], quantize: str):
if quantize == "gptq":
try:
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)

qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
# TODO Get that from file to be more generic
bits = 4
groupsize = 128

bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -110,14 +118,15 @@ class Weights:
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

# TODO Get that from file to be more generic
bits = 4
groupsize = 128
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()

weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
else:
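Closing note: a hedged sketch (not part of this commit) of reading the new metadata back outside the server, mirroring the Weights.get_tensor("gptq_bits").item() calls above; the shard name is a placeholder, and this assumes the scalar torch.LongTensor([bits]) form when saving:

from safetensors.torch import load_file

state = load_file("model-00001-of-00002.safetensors")  # placeholder shard name
bits = int(state["gptq_bits"].item())  # e.g. 4
groupsize = int(state["gptq_groupsize"].item())  # e.g. 128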