Functionning quantization script.

This commit is contained in:
Ubuntu 2023-06-13 11:45:08 +00:00 committed by Nicolas Patry
parent 5a72715344
commit a0a194c391
8 changed files with 842 additions and 381 deletions

View File

@ -150,14 +150,16 @@ def download_weights(
# Convert pytorch weights to safetensors # Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files) utils.convert_files(local_pt_files, local_st_files)
@app.command() @app.command()
def quantize( def quantize(
model_id: str, model_id: str,
output_dir: str,
revision: Optional[str] = None, revision: Optional[str] = None,
logger_level: str = "INFO", logger_level: str = "INFO",
json_output: bool = False, json_output: bool = False,
): ):
extension: str = ".safetensors", extension: str = (".safetensors",)
# Remove default handler # Remove default handler
logger.remove() logger.remove()
logger.add( logger.add(
@ -169,12 +171,15 @@ def quantize(
backtrace=True, backtrace=True,
diagnose=False, diagnose=False,
) )
download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output) download_weights(
model_id=model_id,
revision=revision,
logger_level=logger_level,
json_output=json_output,
)
from text_generation_server.utils.gptq.quantize import quantize from text_generation_server.utils.gptq.quantize import quantize
quantize(model_id=model_id, wbits=4, groupsize=128)
quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -248,7 +248,9 @@ def get_model(
if sharded: if sharded:
raise ValueError("sharded is not supported for AutoModel") raise ValueError("sharded is not supported for AutoModel")
if quantize == "gptq": if quantize == "gptq":
raise ValueError("gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") raise ValueError(
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM( return CausalLM(

View File

@ -59,7 +59,7 @@ def load_row(config, prefix: str, weights, bias: bool):
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
weight = weights.get_multi_weights_col([prefix], quantize=config.quantize) weight = weights.get_multi_weights_col([prefix], quantize=config.quantize)
if isinstance(weight, torch.Tensor): if isinstance(weight, torch.Tensor):
# Only on non quantized versions # Only on non quantized versions
weight = ( weight = (
weight.view( weight.view(
@ -75,7 +75,6 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = weights.get_sharded(f"{prefix}.bias", dim=0)
bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1) bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias, config.quantize)
if config.use_parallel_residual: if config.use_parallel_residual:
return linear return linear

View File

@ -1,4 +1,4 @@
#https://github.com/fpgaminer/GPTQ-triton # https://github.com/fpgaminer/GPTQ-triton
""" """
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
""" """
@ -12,15 +12,23 @@ import triton
class Autotuner(triton.KernelInterface): class Autotuner(triton.KernelInterface):
def __init__(
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False): self,
''' fn,
:param prune_configs_by: a dict of functions that are used to prune configs, fields: arg_names,
'perf_model': performance model used to predicate running time with different configs, returns running time configs,
'top_k': number of configs to bench key,
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. reset_to_zero,
'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results prune_configs_by: Dict = None,
''' nearest_power_of_two: bool = False,
):
"""
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
"""
if not configs: if not configs:
self.configs = [triton.Config({}, num_warps=4, num_stages=2)] self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
else: else:
@ -41,9 +49,12 @@ class Autotuner(triton.KernelInterface):
self.arg_names = arg_names self.arg_names = arg_names
# prune configs # prune configs
if prune_configs_by: if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] perf_model, top_k = (
if 'early_config_prune' in prune_configs_by: prune_configs_by["perf_model"],
early_config_prune = prune_configs_by['early_config_prune'] prune_configs_by["top_k"],
)
if "early_config_prune" in prune_configs_by:
early_config_prune = prune_configs_by["early_config_prune"]
else: else:
perf_model, top_k, early_config_prune = None, None, None perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k self.perf_model, self.configs_top_k = perf_model, top_k
@ -55,8 +66,10 @@ class Autotuner(triton.KernelInterface):
# as kwargs and by the autotuner # as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys() conflicts = meta.keys() & config.kwargs.keys()
if conflicts: if conflicts:
raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." raise ValueError(
" Make sure that you don't re-define auto-tuned symbols.") f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones # augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs) current = dict(meta, **config.kwargs)
@ -64,14 +77,21 @@ class Autotuner(triton.KernelInterface):
if config.pre_hook: if config.pre_hook:
config.pre_hook(self.nargs) config.pre_hook(self.nargs)
self.hook(args) self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**current,
)
try: try:
# In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
# PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40) return triton.testing.do_bench(
kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40
)
except triton.compiler.OutOfResources: except triton.compiler.OutOfResources:
return (float('inf'), float('inf'), float('inf')) return (float("inf"), float("inf"), float("inf"))
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args)) self.nargs = dict(zip(self.arg_names, args))
@ -81,13 +101,16 @@ class Autotuner(triton.KernelInterface):
# This reduces the amount of autotuning by rounding the keys to the nearest power of two # This reduces the amount of autotuning by rounding the keys to the nearest power of two
# In my testing this gives decent results, and greatly reduces the amount of tuning required # In my testing this gives decent results, and greatly reduces the amount of tuning required
if self.nearest_power_of_two: if self.nearest_power_of_two:
key = tuple([2**int(math.log2(x) + 0.5) for x in key]) key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
if key not in self.cache: if key not in self.cache:
# prune configs # prune configs
pruned_configs = self.prune_configs(kwargs) pruned_configs = self.prune_configs(kwargs)
bench_start = time.time() bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} timings = {
config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs
}
bench_end = time.time() bench_end = time.time()
self.bench_time = bench_end - bench_start self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get) self.cache[key] = builtins.min(timings, key=timings.get)
@ -99,7 +122,13 @@ class Autotuner(triton.KernelInterface):
self.best_config = config self.best_config = config
if config.pre_hook is not None: if config.pre_hook is not None:
config.pre_hook(self.nargs) config.pre_hook(self.nargs)
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) return self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
def prune_configs(self, kwargs): def prune_configs(self, kwargs):
pruned_configs = self.configs pruned_configs = self.configs
@ -110,8 +139,19 @@ class Autotuner(triton.KernelInterface):
if isinstance(top_k, float) and top_k <= 1.0: if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k) top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k: if len(pruned_configs) > top_k:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} est_timing = {
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] config: self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps,
)
for config in pruned_configs
}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[
:top_k
]
return pruned_configs return pruned_configs
def warmup(self, *args, **kwargs): def warmup(self, *args, **kwargs):
@ -127,39 +167,49 @@ class Autotuner(triton.KernelInterface):
self.nargs = None self.nargs = None
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): def autotune(
configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False
):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple time.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
""" """
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple time.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn): def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two) return Autotuner(
fn,
fn.arg_names,
configs,
key,
reset_to_zero,
prune_configs_by,
nearest_power_of_two,
)
return decorator return decorator
@ -168,26 +218,44 @@ def matmul248_kernel_config_pruner(configs, nargs):
""" """
The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
""" """
m = max(2**int(math.ceil(math.log2(nargs['M']))), 16) m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
n = max(2**int(math.ceil(math.log2(nargs['N']))), 16) n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
k = max(2**int(math.ceil(math.log2(nargs['K']))), 16) k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)
used = set() used = set()
for config in configs: for config in configs:
block_size_m = min(m, config.kwargs['BLOCK_SIZE_M']) block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
group_size_m = config.kwargs['GROUP_SIZE_M'] group_size_m = config.kwargs["GROUP_SIZE_M"]
if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used: if (
block_size_m,
block_size_n,
block_size_k,
group_size_m,
config.num_stages,
config.num_warps,
) in used:
continue continue
used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps)) used.add(
yield triton.Config({ (
'BLOCK_SIZE_M': block_size_m, block_size_m,
'BLOCK_SIZE_N': block_size_n, block_size_n,
'BLOCK_SIZE_K': block_size_k, block_size_k,
'GROUP_SIZE_M': group_size_m group_size_m,
}, config.num_stages,
num_stages=config.num_stages, config.num_warps,
num_warps=config.num_warps) )
)
yield triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
},
num_stages=config.num_stages,
num_warps=config.num_warps,
)

View File

@ -12,66 +12,121 @@ try:
# code based https://github.com/fpgaminer/GPTQ-triton # code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune( @custom_autotune.autotune(
configs=[ configs=[
triton.Config({ triton.Config(
'BLOCK_SIZE_M': 64, {
'BLOCK_SIZE_N': 256, "BLOCK_SIZE_M": 64,
'BLOCK_SIZE_K': 32, "BLOCK_SIZE_N": 256,
'GROUP_SIZE_M': 8 "BLOCK_SIZE_K": 32,
}, num_stages=4, num_warps=4), "GROUP_SIZE_M": 8,
triton.Config({ },
'BLOCK_SIZE_M': 128, num_stages=4,
'BLOCK_SIZE_N': 128, num_warps=4,
'BLOCK_SIZE_K': 32, ),
'GROUP_SIZE_M': 8 triton.Config(
}, num_stages=4, num_warps=4), {
triton.Config({ "BLOCK_SIZE_M": 128,
'BLOCK_SIZE_M': 64, "BLOCK_SIZE_N": 128,
'BLOCK_SIZE_N': 128, "BLOCK_SIZE_K": 32,
'BLOCK_SIZE_K': 32, "GROUP_SIZE_M": 8,
'GROUP_SIZE_M': 8 },
}, num_stages=4, num_warps=4), num_stages=4,
triton.Config({ num_warps=4,
'BLOCK_SIZE_M': 128, ),
'BLOCK_SIZE_N': 32, triton.Config(
'BLOCK_SIZE_K': 32, {
'GROUP_SIZE_M': 8 "BLOCK_SIZE_M": 64,
}, num_stages=4, num_warps=4), "BLOCK_SIZE_N": 128,
triton.Config({ "BLOCK_SIZE_K": 32,
'BLOCK_SIZE_M': 64, "GROUP_SIZE_M": 8,
'BLOCK_SIZE_N': 64, },
'BLOCK_SIZE_K': 32, num_stages=4,
'GROUP_SIZE_M': 8 num_warps=4,
}, num_stages=4, num_warps=4), ),
triton.Config({ triton.Config(
'BLOCK_SIZE_M': 64, {
'BLOCK_SIZE_N': 128, "BLOCK_SIZE_M": 128,
'BLOCK_SIZE_K': 32, "BLOCK_SIZE_N": 32,
'GROUP_SIZE_M': 8 "BLOCK_SIZE_K": 32,
}, num_stages=2, num_warps=8), "GROUP_SIZE_M": 8,
triton.Config({ },
'BLOCK_SIZE_M': 64, num_stages=4,
'BLOCK_SIZE_N': 64, num_warps=4,
'BLOCK_SIZE_K': 64, ),
'GROUP_SIZE_M': 8 triton.Config(
}, num_stages=3, num_warps=8), {
triton.Config({ "BLOCK_SIZE_M": 64,
'BLOCK_SIZE_M': 32, "BLOCK_SIZE_N": 64,
'BLOCK_SIZE_N': 32, "BLOCK_SIZE_K": 32,
'BLOCK_SIZE_K': 128, "GROUP_SIZE_M": 8,
'GROUP_SIZE_M': 8 },
}, num_stages=2, num_warps=4), num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=4,
),
], ],
key=['M', 'N', 'K'], key=["M", "N", "K"],
nearest_power_of_two=True, nearest_power_of_two=True,
prune_configs_by={ prune_configs_by={
'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
'perf_model': None, "perf_model": None,
'top_k': None, "top_k": None,
}, },
) )
@triton.jit @triton.jit
def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, def matmul_248_kernel(
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
g_ptr,
M,
N,
K,
bits,
maxq,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
""" """
Compute the matrix multiplication C = A x B. Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16 A is of shape (M, K) float16
@ -79,7 +134,7 @@ try:
C is of shape (M, N) float16 C is of shape (M, N) float16
scales is of shape (G, N) float16 scales is of shape (G, N) float16
zeros is of shape (G, N) float16 zeros is of shape (G, N) float16
g_ptr is of shape (K) int32 g_ptr is of shape (K) int32
""" """
infearure_per_bits = 32 // bits infearure_per_bits = 32 // bits
@ -97,10 +152,15 @@ try:
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) a_ptrs = a_ptr + (
a_mask = (offs_am[:, None] < M) offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = offs_am[:, None] < M
# b_ptrs is set up such that it repeats elements along the K axis 8 times # b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) b_ptrs = b_ptr + (
(offs_k[:, None] // infearure_per_bits) * stride_bk
+ offs_bn[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B # shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :] scales_ptrs = scales_ptr + offs_bn[None, :]
@ -114,13 +174,17 @@ try:
g_idx = tl.load(g_ptrs) g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) scales = tl.load(
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) scales_ptrs + g_idx[:, None] * stride_scales
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(
zeros_ptrs + g_idx[:, None] * stride_zeros
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1) zeros = zeros + 1
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values # Now we need to unpack b (which is N-bit values) into 32-bit values
@ -136,61 +200,118 @@ try:
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
@custom_autotune.autotune(configs=[ @custom_autotune.autotune(
triton.Config({ configs=[
'BLOCK_SIZE_M': 64, triton.Config(
'BLOCK_SIZE_N': 32, {
'BLOCK_SIZE_K': 256, "BLOCK_SIZE_M": 64,
'GROUP_SIZE_M': 8 "BLOCK_SIZE_N": 32,
}, num_stages=4, num_warps=4), "BLOCK_SIZE_K": 256,
triton.Config({ "GROUP_SIZE_M": 8,
'BLOCK_SIZE_M': 128, },
'BLOCK_SIZE_N': 32, num_stages=4,
'BLOCK_SIZE_K': 128, num_warps=4,
'GROUP_SIZE_M': 8 ),
}, num_stages=4, num_warps=4), triton.Config(
triton.Config({ {
'BLOCK_SIZE_M': 64, "BLOCK_SIZE_M": 128,
'BLOCK_SIZE_N': 32, "BLOCK_SIZE_N": 32,
'BLOCK_SIZE_K': 128, "BLOCK_SIZE_K": 128,
'GROUP_SIZE_M': 8 "GROUP_SIZE_M": 8,
}, num_stages=4, num_warps=4), },
triton.Config({ num_stages=4,
'BLOCK_SIZE_M': 128, num_warps=4,
'BLOCK_SIZE_N': 32, ),
'BLOCK_SIZE_K': 32, triton.Config(
'GROUP_SIZE_M': 8 {
}, num_stages=4, num_warps=4), "BLOCK_SIZE_M": 64,
triton.Config({ "BLOCK_SIZE_N": 32,
'BLOCK_SIZE_M': 64, "BLOCK_SIZE_K": 128,
'BLOCK_SIZE_N': 32, "GROUP_SIZE_M": 8,
'BLOCK_SIZE_K': 64, },
'GROUP_SIZE_M': 8 num_stages=4,
}, num_stages=4, num_warps=4), num_warps=4,
triton.Config({ ),
'BLOCK_SIZE_M': 64, triton.Config(
'BLOCK_SIZE_N': 32, {
'BLOCK_SIZE_K': 128, "BLOCK_SIZE_M": 128,
'GROUP_SIZE_M': 8 "BLOCK_SIZE_N": 32,
}, num_stages=2, num_warps=8), "BLOCK_SIZE_K": 32,
triton.Config({ "GROUP_SIZE_M": 8,
'BLOCK_SIZE_M': 64, },
'BLOCK_SIZE_N': 64, num_stages=4,
'BLOCK_SIZE_K': 64, num_warps=4,
'GROUP_SIZE_M': 8 ),
}, num_stages=3, num_warps=8), triton.Config(
triton.Config({ {
'BLOCK_SIZE_M': 32, "BLOCK_SIZE_M": 64,
'BLOCK_SIZE_N': 128, "BLOCK_SIZE_N": 32,
'BLOCK_SIZE_K': 32, "BLOCK_SIZE_K": 64,
'GROUP_SIZE_M': 8 "GROUP_SIZE_M": 8,
}, num_stages=2, num_warps=4), },
], num_stages=4,
key=['M', 'N', 'K'], num_warps=4,
nearest_power_of_two=True) ),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=4,
),
],
key=["M", "N", "K"],
nearest_power_of_two=True,
)
@triton.jit @triton.jit
def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, def transpose_matmul_248_kernel(
stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
g_ptr,
M,
N,
K,
bits,
maxq,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
""" """
Compute the matrix multiplication C = A x B. Compute the matrix multiplication C = A x B.
A is of shape (M, N) float16 A is of shape (M, N) float16
@ -198,7 +319,7 @@ try:
C is of shape (M, K) float16 C is of shape (M, K) float16
scales is of shape (G, N) float16 scales is of shape (G, N) float16
zeros is of shape (G, N) float16 zeros is of shape (G, N) float16
g_ptr is of shape (K) int32 g_ptr is of shape (K) int32
""" """
infearure_per_bits = 32 // bits infearure_per_bits = 32 // bits
@ -216,16 +337,25 @@ try:
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_n = tl.arange(0, BLOCK_SIZE_N) offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) a_ptrs = a_ptr + (
a_mask = (offs_am[:, None] < M) offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = offs_am[:, None] < M
# b_ptrs is set up such that it repeats elements along the K axis 8 times # b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) b_ptrs = b_ptr + (
(offs_bk[:, None] // infearure_per_bits) * stride_bk
+ offs_n[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk g_ptrs = g_ptr + offs_bk
g_idx = tl.load(g_ptrs) g_idx = tl.load(g_ptrs)
# shifter is used to extract the N bits of each element in the 32-bit word from B # shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros zeros_ptrs = (
zeros_ptr
+ (offs_n[None, :] // infearure_per_bits)
+ g_idx[:, None] * stride_zeros
)
shifter = (offs_bk % infearure_per_bits) * bits shifter = (offs_bk % infearure_per_bits) * bits
zeros_shifter = (offs_n % infearure_per_bits) * bits zeros_shifter = (offs_n % infearure_per_bits) * bits
@ -237,9 +367,9 @@ try:
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1) zeros = zeros + 1
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values # Now we need to unpack b (which is N-bit values) into 32-bit values
@ -251,36 +381,84 @@ try:
a_ptrs += BLOCK_SIZE_N a_ptrs += BLOCK_SIZE_N
b_ptrs += BLOCK_SIZE_N b_ptrs += BLOCK_SIZE_N
scales_ptrs += BLOCK_SIZE_N scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
except: except:
print('triton not installed.') print("triton not installed.")
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device): with torch.cuda.device(input.device):
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) output = torch.empty(
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), )
qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) grid = lambda META: (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)
matmul_248_kernel[grid](
input,
qweight,
output,
scales,
qzeros,
g_idx,
input.shape[0],
qweight.shape[1],
input.shape[1],
bits,
maxq,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
)
return output return output
def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device): with torch.cuda.device(input.device):
output_dim = (qweight.shape[0] * 32) // bits output_dim = (qweight.shape[0] * 32) // bits
output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16) output = torch.empty(
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) (input.shape[0], output_dim), device=input.device, dtype=torch.float16
transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), )
qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) grid = lambda META: (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(output_dim, META["BLOCK_SIZE_K"]),
)
transpose_matmul_248_kernel[grid](
input,
qweight,
output,
scales,
qzeros,
g_idx,
input.shape[0],
qweight.shape[1],
output_dim,
bits,
maxq,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
)
return output return output
class QuantLinearFunction(torch.autograd.Function): class QuantLinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
@ -297,7 +475,9 @@ class QuantLinearFunction(torch.autograd.Function):
grad_input = None grad_input = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq) grad_input = transpose_matmul248(
grad_output, qweight, scales, qzeros, g_idx, bits, maxq
)
return grad_input, None, None, None, None, None, None return grad_input, None, None, None, None, None, None
@ -318,8 +498,41 @@ class QuantLinear(nn.Module):
self.outfeatures = qweight.shape[1] self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // 4 self.infeatures = qweight.shape[0] * 32 // 4
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = torch.zeros(
(infeatures // 32 * self.bits, outfeatures), dtype=torch.int32
)
qzeros = torch.zeros(
(math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // self.groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def forward(self, x): def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures, ) out_shape = x.shape[:-1] + (self.outfeatures,)
out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq) out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.bits,
self.maxq,
)
out = out + self.bias if self.bias is not None else out out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape) return out.reshape(out_shape)

View File

@ -4,26 +4,36 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import math import math
import os
from texttable import Texttable from texttable import Texttable
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
import transformers import transformers
import numpy as np import numpy as np
import torch import torch
from text_generation_server.utils.gptq.quant_linear import QuantLinear
DEV = torch.device("cuda:0") DEV = torch.device("cuda:0")
class Quantizer(nn.Module): class Quantizer(nn.Module):
def __init__(self, shape=1): def __init__(self, shape=1):
super(Quantizer, self).__init__() super(Quantizer, self).__init__()
self.register_buffer('maxq', torch.tensor(0)) self.register_buffer("maxq", torch.tensor(0))
self.register_buffer('scale', torch.zeros(shape)) self.register_buffer("scale", torch.zeros(shape))
self.register_buffer('zero', torch.zeros(shape)) self.register_buffer("zero", torch.zeros(shape))
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):
def configure(
self,
bits,
perchannel=False,
sym=True,
mse=False,
norm=2.4,
grid=100,
maxshrink=0.8,
trits=False,
):
self.maxq = torch.tensor(2**bits - 1) self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel self.perchannel = perchannel
self.sym = sym self.sym = sym
@ -84,14 +94,16 @@ class Quantizer(nn.Module):
self.zero = torch.round(-xmin / self.scale) self.zero = torch.round(-xmin / self.scale)
if self.mse: if self.mse:
best = torch.full([x.shape[0]], float('inf'), device=dev) best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)): for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid p = 1 - i / self.grid
xmin1 = p * xmin xmin1 = p * xmin
xmax1 = p * xmax xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) q = self._quantize(
x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
)
q -= x q -= x
q.abs_() q.abs_()
q.pow_(self.norm) q.pow_(self.norm)
@ -138,7 +150,6 @@ class Quantizer(nn.Module):
class GPTQ: class GPTQ:
def __init__(self, layer, observe=False): def __init__(self, layer, observe=False):
self.layer = layer self.layer = layer
self.dev = self.layer.weight.device self.dev = self.layer.weight.device
@ -166,12 +177,19 @@ class GPTQ:
if len(inp.shape) == 2: if len(inp.shape) == 2:
inp = inp.unsqueeze(0) inp = inp.unsqueeze(0)
tmp = inp.shape[0] tmp = inp.shape[0]
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): if isinstance(self.layer, nn.Linear) or isinstance(
self.layer, transformers.Conv1D
):
if len(inp.shape) == 3: if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1])) inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t() inp = inp.t()
if isinstance(self.layer, nn.Conv2d): if isinstance(self.layer, nn.Conv2d):
unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride) unfold = nn.Unfold(
self.layer.kernel_size,
dilation=self.layer.dilation,
padding=self.layer.padding,
stride=self.layer.stride,
)
inp = unfold(inp) inp = unfold(inp)
inp = inp.permute([1, 0, 2]) inp = inp.permute([1, 0, 2])
inp = inp.flatten(1) inp = inp.flatten(1)
@ -184,12 +202,14 @@ class GPTQ:
def print_loss(self, name, q_weight, weight_error, timecost): def print_loss(self, name, q_weight, weight_error, timecost):
table = Texttable() table = Texttable()
name += ' ' * (16 - len(name)) name += " " * (16 - len(name))
table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time']) table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])
# assign weight # assign weight
self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
self.layer.weight.data.dtype
)
if self.inp1 is not None: if self.inp1 is not None:
# quantize input to int8 # quantize input to int8
@ -203,13 +223,15 @@ class GPTQ:
q_SNR = torch_snr_error(q_out, self.out1).item() q_SNR = torch_snr_error(q_out, self.out1).item()
fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item() fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
else: else:
q_SNR = '-' q_SNR = "-"
fp_SNR = '-' fp_SNR = "-"
table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
print(table.draw().split('\n')[-2]) print(table.draw().split("\n")[-2])
def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''): def fasterquant(
self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
):
self.layer.to(self.dev) self.layer.to(self.dev)
W = self.layer.weight.data.clone() W = self.layer.weight.data.clone()
@ -268,7 +290,9 @@ class GPTQ:
if groupsize != -1: if groupsize != -1:
if (i1 + i) % groupsize == 0: if (i1 + i) % groupsize == 0:
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) self.quantizer.find_params(
W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
)
if ((i1 + i) // groupsize) - now_idx == -1: if ((i1 + i) // groupsize) - now_idx == -1:
scale.append(self.quantizer.scale) scale.append(self.quantizer.scale)
@ -277,7 +301,7 @@ class GPTQ:
q = self.quantizer.quantize(w.unsqueeze(1)).flatten() q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q Q1[:, i] = q
Losses1[:, i] = (w - q)**2 / d**2 Losses1[:, i] = (w - q) ** 2 / d**2
err1 = (w - q) / d err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
@ -302,7 +326,9 @@ class GPTQ:
if isinstance(self.layer, transformers.Conv1D): if isinstance(self.layer, transformers.Conv1D):
Q = Q.t() Q = Q.t()
self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)) self.print_loss(
name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
)
if scale == []: if scale == []:
scale.append(self.quantizer.scale) scale.append(self.quantizer.scale)
@ -322,15 +348,18 @@ class GPTQ:
def get_wikitext2(nsamples, seed, seqlen, model_id): def get_wikitext2(nsamples, seed, seqlen, model_id):
from datasets import load_dataset from datasets import load_dataset
traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
from transformers import AutoTokenizer from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
import random import random
random.seed(seed) random.seed(seed)
trainloader = [] trainloader = []
for _ in range(nsamples): for _ in range(nsamples):
@ -345,18 +374,21 @@ def get_wikitext2(nsamples, seed, seqlen, model_id):
def get_ptb(nsamples, seed, seqlen, model_id): def get_ptb(nsamples, seed, seqlen, model_id):
from datasets import load_dataset from datasets import load_dataset
traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")
from transformers import AutoTokenizer from transformers import AutoTokenizer
try: try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
except: except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
import random import random
random.seed(seed) random.seed(seed)
trainloader = [] trainloader = []
for _ in range(nsamples): for _ in range(nsamples):
@ -371,22 +403,37 @@ def get_ptb(nsamples, seed, seqlen, model_id):
def get_c4(nsamples, seed, seqlen, model_id): def get_c4(nsamples, seed, seqlen, model_id):
from datasets import load_dataset from datasets import load_dataset
traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False)
valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False) traindata = load_dataset(
"allenai/c4",
"allenai--c4",
data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
split="train",
use_auth_token=False,
)
valdata = load_dataset(
"allenai/c4",
"allenai--c4",
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
split="validation",
use_auth_token=False,
)
from transformers import AutoTokenizer from transformers import AutoTokenizer
try: try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
except: except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
import random import random
random.seed(seed) random.seed(seed)
trainloader = [] trainloader = []
for _ in range(nsamples): for _ in range(nsamples):
while True: while True:
i = random.randint(0, len(traindata) - 1) i = random.randint(0, len(traindata) - 1)
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
if trainenc.input_ids.shape[1] >= seqlen: if trainenc.input_ids.shape[1] >= seqlen:
break break
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
@ -397,12 +444,13 @@ def get_c4(nsamples, seed, seqlen, model_id):
trainloader.append((inp, tar)) trainloader.append((inp, tar))
import random import random
random.seed(0) random.seed(0)
valenc = [] valenc = []
for _ in range(256): for _ in range(256):
while True: while True:
i = random.randint(0, len(valdata) - 1) i = random.randint(0, len(valdata) - 1)
tmp = tokenizer(valdata[i]['text'], return_tensors='pt') tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
if tmp.input_ids.shape[1] >= seqlen: if tmp.input_ids.shape[1] >= seqlen:
break break
i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
@ -411,7 +459,6 @@ def get_c4(nsamples, seed, seqlen, model_id):
valenc = torch.hstack(valenc) valenc = torch.hstack(valenc)
class TokenizerWrapper: class TokenizerWrapper:
def __init__(self, input_ids): def __init__(self, input_ids):
self.input_ids = input_ids self.input_ids = input_ids
@ -422,18 +469,21 @@ def get_c4(nsamples, seed, seqlen, model_id):
def get_ptb_new(nsamples, seed, seqlen, model_id): def get_ptb_new(nsamples, seed, seqlen, model_id):
from datasets import load_dataset from datasets import load_dataset
traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
from transformers import AutoTokenizer from transformers import AutoTokenizer
try: try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
except: except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")
import random import random
random.seed(seed) random.seed(seed)
trainloader = [] trainloader = []
for _ in range(nsamples): for _ in range(nsamples):
@ -448,22 +498,35 @@ def get_ptb_new(nsamples, seed, seqlen, model_id):
def get_c4_new(nsamples, seed, seqlen, model_id): def get_c4_new(nsamples, seed, seqlen, model_id):
from datasets import load_dataset from datasets import load_dataset
traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') traindata = load_dataset(
"allenai/c4",
"allenai--c4",
data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
split="train",
)
valdata = load_dataset(
"allenai/c4",
"allenai--c4",
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
split="validation",
)
from transformers import AutoTokenizer from transformers import AutoTokenizer
try: try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
except: except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
import random import random
random.seed(seed) random.seed(seed)
trainloader = [] trainloader = []
for _ in range(nsamples): for _ in range(nsamples):
while True: while True:
i = random.randint(0, len(traindata) - 1) i = random.randint(0, len(traindata) - 1)
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
if trainenc.input_ids.shape[1] >= seqlen: if trainenc.input_ids.shape[1] >= seqlen:
break break
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
@ -473,11 +536,10 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
tar[:, :-1] = -100 tar[:, :-1] = -100
trainloader.append((inp, tar)) trainloader.append((inp, tar))
valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
valenc = valenc.input_ids[:, :(256 * seqlen)] valenc = valenc.input_ids[:, : (256 * seqlen)]
class TokenizerWrapper: class TokenizerWrapper:
def __init__(self, input_ids): def __init__(self, input_ids):
self.input_ids = input_ids self.input_ids = input_ids
@ -486,31 +548,46 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
return trainloader, valenc return trainloader, valenc
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=''): def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""):
if 'wikitext2' in name: if "wikitext2" in name:
return get_wikitext2(nsamples, seed, seqlen, model_id) return get_wikitext2(nsamples, seed, seqlen, model_id)
if 'ptb' in name: if "ptb" in name:
if 'new' in name: if "new" in name:
return get_ptb_new(nsamples, seed, seqlen, model_id) return get_ptb_new(nsamples, seed, seqlen, model_id)
return get_ptb(nsamples, seed, seqlen, model_id) return get_ptb(nsamples, seed, seqlen, model_id)
if 'c4' in name: if "c4" in name:
if 'new' in name: if "new" in name:
return get_c4_new(nsamples, seed, seqlen, model_id) return get_c4_new(nsamples, seed, seqlen, model_id)
return get_c4(nsamples, seed, seqlen, model_id) return get_c4(nsamples, seed, seqlen, model_id)
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
# Skip last lm_head linear # Skip last lm_head linear
if type(module) in layers and "lm_head" not in name: if type(module) in layers and "lm_head" not in name:
return {name: module} return {name: module}
res = {} res = {}
for name1, child in module.named_children(): for name1, child in module.named_children():
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) res.update(
find_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1
)
)
return res return res
@torch.no_grad() @torch.no_grad()
def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False): def sequential(
print('Starting ...') model,
dataloader,
dev,
nsamples,
bits,
groupsize,
percdamp=0.01,
sym: bool = False,
act_order: bool = False,
):
print("Starting ...")
use_cache = model.config.use_cache use_cache = model.config.use_cache
model.config.use_cache = False model.config.use_cache = False
@ -524,20 +601,21 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
# layers[0] = layers[0].to(dev) # layers[0] = layers[0].to(dev)
dtype = next(iter(model.parameters())).dtype dtype = next(iter(model.parameters())).dtype
inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) inps = torch.zeros(
cache = {'i': 0, 'attention_mask': None} (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
)
cache = {"i": 0, "attention_mask": None}
class Catcher(nn.Module): class Catcher(nn.Module):
def __init__(self, module): def __init__(self, module):
super().__init__() super().__init__()
self.module = module self.module = module
def forward(self, inp, **kwargs): def forward(self, inp, **kwargs):
inps[cache['i']] = inp inps[cache["i"]] = inp
cache['i'] += 1 cache["i"] += 1
cache['attention_mask'] = kwargs['attention_mask'] cache["attention_mask"] = kwargs["attention_mask"]
cache['position_ids'] = kwargs['position_ids'] cache["position_ids"] = kwargs["position_ids"]
raise ValueError raise ValueError
layers[0] = Catcher(layers[0]) layers[0] = Catcher(layers[0])
@ -554,20 +632,20 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
torch.cuda.empty_cache() torch.cuda.empty_cache()
outs = torch.zeros_like(inps) outs = torch.zeros_like(inps)
attention_mask = cache['attention_mask'].to(dev) attention_mask = cache["attention_mask"].to(dev)
position_ids = cache['position_ids'].to(dev) position_ids = cache["position_ids"].to(dev)
print('Ready.') print("Ready.")
quantizers = {} quantizers = {}
for i in range(len(layers)): for i in range(len(layers)):
print(f"Quantizing layer {i+1}/{len(layers)}..")
print(f'Quantizing layer {i+1}/{len(layers)}..') print("+------------------+--------------+------------+-----------+-------+")
print('+------------------+--------------+------------+-----------+-------+') print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |') print("+==================+==============+============+===========+=======+")
print('+==================+==============+============+===========+=======+')
from accelerate.hooks import remove_hook_from_submodules from accelerate.hooks import remove_hook_from_submodules
layer = layers[i].to(dev) layer = layers[i].to(dev)
remove_hook_from_submodules(layer) remove_hook_from_submodules(layer)
full = find_layers(layer) full = find_layers(layer)
@ -578,10 +656,11 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
gptq = {} gptq = {}
for name in subset: for name in subset:
gptq[name] = GPTQ(subset[name]) gptq[name] = GPTQ(subset[name])
gptq[name].quantizer.configure(wbits, perchannel=True, sym=sym, mse=False) gptq[name].quantizer.configure(
bits, perchannel=True, sym=sym, mse=False
)
def add_batch(name): def add_batch(name):
def tmp(_, inp, out): def tmp(_, inp, out):
gptq[name].add_batch(inp[0].data, out.data) gptq[name].add_batch(inp[0].data, out.data)
@ -591,19 +670,38 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
for name in subset: for name in subset:
handles.append(subset[name].register_forward_hook(add_batch(name))) handles.append(subset[name].register_forward_hook(add_batch(name)))
for j in range(nsamples): for j in range(nsamples):
outs[j] = layer(
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] inps[j].unsqueeze(0),
attention_mask=attention_mask,
position_ids=position_ids,
)[0]
for h in handles: for h in handles:
h.remove() h.remove()
for name in subset: for name in subset:
scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=percdamp, groupsize=groupsize, actorder=act_order, name=name) scale, zero, g_idx, error = gptq[name].fasterquant(
quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize) percdamp=percdamp,
groupsize=groupsize,
actorder=act_order,
name=name,
)
quantizers["model.layers.%d.%s" % (i, name)] = (
gptq[name].quantizer.cpu(),
scale.cpu(),
zero.cpu(),
g_idx.cpu(),
bits,
groupsize,
)
gptq[name].free() gptq[name].free()
for j in range(nsamples): for j in range(nsamples):
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] outs[j] = layer(
inps[j].unsqueeze(0),
attention_mask=attention_mask,
position_ids=position_ids,
)[0]
layers[i] = layer.cpu() layers[i] = layer.cpu()
del layer del layer
@ -611,12 +709,12 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
torch.cuda.empty_cache() torch.cuda.empty_cache()
inps, outs = outs, inps inps, outs = outs, inps
print('+------------------+--------------+------------+-----------+-------+') print("+------------------+--------------+------------+-----------+-------+")
print('\n') print("\n")
# if args.observe: # if args.observe:
# observer.print() # observer.print()
# conditions = gen_conditions(args.wbits, args.groupsize) # conditions = gen_conditions(args.bits, args.groupsize)
# for item in observer.items(): # for item in observer.items():
# name = item[0] # name = item[0]
# layerid = item[1] # layerid = item[1]
@ -625,23 +723,23 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
# target = error / 2 # target = error / 2
# table = Texttable() # table = Texttable()
# table.header(['wbits', 'groupsize', 'error']) # table.header(['bits', 'groupsize', 'error'])
# table.set_cols_dtype(['i', 'i', 'f']) # table.set_cols_dtype(['i', 'i', 'f'])
# table.add_row([args.wbits, args.groupsize, error]) # table.add_row([args.bits, args.groupsize, error])
# print('Optimizing {} {} ..'.format(name, layerid)) # print('Optimizing {} {} ..'.format(name, layerid))
# for wbits, groupsize in conditions: # for bits, groupsize in conditions:
# if error < target: # if error < target:
# # if error dropped 50%, skip # # if error dropped 50%, skip
# break # break
# gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False) # gptq.quantizer.configure(bits, perchannel=True, sym=args.sym, mse=False)
# scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name) # scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name)
# table.add_row([wbits, groupsize, error]) # table.add_row([bits, groupsize, error])
# quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize) # quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), bits, groupsize)
# print(table.draw()) # print(table.draw())
# print('\n') # print('\n')
@ -656,34 +754,34 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
# @torch.no_grad() # @torch.no_grad()
# def llama_eval(model, testenc, dev): # def llama_eval(model, testenc, dev):
# print('Evaluating ...') # print('Evaluating ...')
# #
# testenc = testenc.input_ids # testenc = testenc.input_ids
# nsamples = testenc.numel() // model.seqlen # nsamples = testenc.numel() // model.seqlen
# #
# use_cache = model.config.use_cache # use_cache = model.config.use_cache
# model.config.use_cache = False # model.config.use_cache = False
# layers = model.model.layers # layers = model.model.layers
# #
# model.model.embed_tokens = model.model.embed_tokens.to(dev) # model.model.embed_tokens = model.model.embed_tokens.to(dev)
# layers[0] = layers[0].to(dev) # layers[0] = layers[0].to(dev)
# #
# dtype = next(iter(model.parameters())).dtype # dtype = next(iter(model.parameters())).dtype
# inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) # inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
# cache = {'i': 0, 'attention_mask': None} # cache = {'i': 0, 'attention_mask': None}
# #
# class Catcher(nn.Module): # class Catcher(nn.Module):
# #
# def __init__(self, module): # def __init__(self, module):
# super().__init__() # super().__init__()
# self.module = module # self.module = module
# #
# def forward(self, inp, **kwargs): # def forward(self, inp, **kwargs):
# inps[cache['i']] = inp # inps[cache['i']] = inp
# cache['i'] += 1 # cache['i'] += 1
# cache['attention_mask'] = kwargs['attention_mask'] # cache['attention_mask'] = kwargs['attention_mask']
# cache['position_ids'] = kwargs['position_ids'] # cache['position_ids'] = kwargs['position_ids']
# raise ValueError # raise ValueError
# #
# layers[0] = Catcher(layers[0]) # layers[0] = Catcher(layers[0])
# for i in range(nsamples): # for i in range(nsamples):
# batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) # batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
@ -692,39 +790,39 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
# except ValueError: # except ValueError:
# pass # pass
# layers[0] = layers[0].module # layers[0] = layers[0].module
# #
# layers[0] = layers[0].cpu() # layers[0] = layers[0].cpu()
# model.model.embed_tokens = model.model.embed_tokens.cpu() # model.model.embed_tokens = model.model.embed_tokens.cpu()
# torch.cuda.empty_cache() # torch.cuda.empty_cache()
# #
# outs = torch.zeros_like(inps) # outs = torch.zeros_like(inps)
# attention_mask = cache['attention_mask'] # attention_mask = cache['attention_mask']
# position_ids = cache['position_ids'] # position_ids = cache['position_ids']
# #
# for i in range(len(layers)): # for i in range(len(layers)):
# print(i) # print(i)
# layer = layers[i].to(dev) # layer = layers[i].to(dev)
# #
# if args.nearest: # if args.nearest:
# subset = find_layers(layer) # subset = find_layers(layer)
# for name in subset: # for name in subset:
# quantizer = quant.Quantizer() # quantizer = quant.Quantizer()
# quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) # quantizer.configure(args.bits, perchannel=True, sym=args.sym, mse=False)
# W = subset[name].weight.data # W = subset[name].weight.data
# quantizer.find_params(W, weight=True) # quantizer.find_params(W, weight=True)
# subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype) # subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
# #
# for j in range(nsamples): # for j in range(nsamples):
# outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] # outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
# layers[i] = layer.cpu() # layers[i] = layer.cpu()
# del layer # del layer
# torch.cuda.empty_cache() # torch.cuda.empty_cache()
# inps, outs = outs, inps # inps, outs = outs, inps
# #
# if model.model.norm is not None: # if model.model.norm is not None:
# model.model.norm = model.model.norm.to(dev) # model.model.norm = model.model.norm.to(dev)
# model.lm_head = model.lm_head.to(dev) # model.lm_head = model.lm_head.to(dev)
# #
# testenc = testenc.to(dev) # testenc = testenc.to(dev)
# nlls = [] # nlls = []
# for i in range(nsamples): # for i in range(nsamples):
@ -740,36 +838,61 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01
# nlls.append(neg_log_likelihood) # nlls.append(neg_log_likelihood)
# ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) # ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
# print(ppl.item()) # print(ppl.item())
# #
# model.config.use_cache = use_cache # model.config.use_cache = use_cache
def make_quant_linear(module, names, bits, groupsize, name=""):
if isinstance(module, QuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
name1 = name + "." + attr if name != "" else attr
if name1 in names:
delattr(module, attr)
setattr(
module,
attr,
QuantLinear.new(
bits,
groupsize,
tmp.in_features,
tmp.out_features,
tmp.bias is not None,
),
)
for name1, child in module.named_children():
make_quant_linear(
child, names, bits, groupsize, name + "." + name1 if name != "" else name1
)
# TODO: perform packing on GPU # TODO: perform packing on GPU
def pack(model, quantizers, wbits, groupsize): def pack(model, quantizers, bits, groupsize):
layers = find_layers(model) layers = find_layers(model)
layers = {n: layers[n] for n in quantizers} layers = {n: layers[n] for n in quantizers}
quant.make_quant_linear(model, quantizers, wbits, groupsize) make_quant_linear(model, quantizers, bits, groupsize)
qlayers = find_layers(model, [QuantLinear]) qlayers = find_layers(model, [QuantLinear])
print('Packing ...') print("Packing ...")
for name in qlayers: for name in qlayers:
print(name) print(name)
quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
qlayers[name].pack(layers[name], scale, zero, g_idx) qlayers[name].pack(layers[name], scale, zero, g_idx)
print('Done.') print("Done.")
return model return model
# def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): # def load_quant(model, checkpoint, bits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
# from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils # from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils
# config = LlamaConfig.from_pretrained(model) # config = LlamaConfig.from_pretrained(model)
# #
# def noop(*args, **kwargs): # def noop(*args, **kwargs):
# pass # pass
# #
# torch.nn.init.kaiming_uniform_ = noop # torch.nn.init.kaiming_uniform_ = noop
# torch.nn.init.uniform_ = noop # torch.nn.init.uniform_ = noop
# torch.nn.init.normal_ = noop # torch.nn.init.normal_ = noop
# #
# torch.set_default_dtype(torch.half) # torch.set_default_dtype(torch.half)
# modeling_utils._init_weights = False # modeling_utils._init_weights = False
# torch.set_default_dtype(torch.half) # torch.set_default_dtype(torch.half)
@ -781,30 +904,30 @@ def pack(model, quantizers, wbits, groupsize):
# for name in ['lm_head']: # for name in ['lm_head']:
# if name in layers: # if name in layers:
# del layers[name] # del layers[name]
# quant.make_quant_linear(model, layers, wbits, groupsize) # quant.make_quant_linear(model, layers, bits, groupsize)
# #
# del layers # del layers
# #
# print('Loading model ...') # print('Loading model ...')
# if checkpoint.endswith('.safetensors'): # if checkpoint.endswith('.safetensors'):
# from safetensors.torch import load_file as safe_load # from safetensors.torch import load_file as safe_load
# model.load_state_dict(safe_load(checkpoint)) # model.load_state_dict(safe_load(checkpoint))
# else: # else:
# model.load_state_dict(torch.load(checkpoint)) # model.load_state_dict(torch.load(checkpoint))
# #
# if eval: # if eval:
# quant.make_quant_attn(model) # quant.make_quant_attn(model)
# quant.make_quant_norm(model) # quant.make_quant_norm(model)
# if fused_mlp: # if fused_mlp:
# quant.make_fused_mlp(model) # quant.make_fused_mlp(model)
# #
# if warmup_autotune: # if warmup_autotune:
# quant.autotune_warmup_linear(model, transpose=not (eval)) # quant.autotune_warmup_linear(model, transpose=not (eval))
# if eval and fused_mlp: # if eval and fused_mlp:
# quant.autotune_warmup_fused(model) # quant.autotune_warmup_fused(model)
# model.seqlen = 2048 # model.seqlen = 2048
# print('Done.') # print('Done.')
# #
# return model # return model
@ -814,33 +937,33 @@ def pack(model, quantizers, wbits, groupsize):
# model.model.norm = model.model.norm.to(gpus[0]) # model.model.norm = model.model.norm.to(gpus[0])
# import copy # import copy
# model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0]) # model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0])
# #
# cache = {'mask': None, 'position_ids': None} # cache = {'mask': None, 'position_ids': None}
# #
# class MoveModule(nn.Module): # class MoveModule(nn.Module):
# #
# def __init__(self, module, invalidate_cache): # def __init__(self, module, invalidate_cache):
# super().__init__() # super().__init__()
# self.module = module # self.module = module
# self.dev = next(iter(self.module.parameters())).device # self.dev = next(iter(self.module.parameters())).device
# self.invalidate_cache=invalidate_cache # self.invalidate_cache=invalidate_cache
# #
# def forward(self, *inp, **kwargs): # def forward(self, *inp, **kwargs):
# inp = list(inp) # inp = list(inp)
# if inp[0].device != self.dev: # if inp[0].device != self.dev:
# inp[0] = inp[0].to(self.dev) # inp[0] = inp[0].to(self.dev)
# #
# if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache: # if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache:
# cache['mask'] = kwargs['attention_mask'].to(self.dev) # cache['mask'] = kwargs['attention_mask'].to(self.dev)
# kwargs['attention_mask'] = cache['mask'] # kwargs['attention_mask'] = cache['mask']
# #
# if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache: # if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache:
# cache['position_ids'] = kwargs['position_ids'].to(self.dev) # cache['position_ids'] = kwargs['position_ids'].to(self.dev)
# kwargs['position_ids'] = cache['position_ids'] # kwargs['position_ids'] = cache['position_ids']
# #
# tmp = self.module(*inp, **kwargs) # tmp = self.module(*inp, **kwargs)
# return tmp # return tmp
# #
# layers = model.model.layers # layers = model.model.layers
# from math import ceil # from math import ceil
# if not gpu_dist: # if not gpu_dist:
@ -852,49 +975,49 @@ def pack(model, quantizers, wbits, groupsize):
# assigned_gpus = [0] * (gpu_dist[0]-1) # assigned_gpus = [0] * (gpu_dist[0]-1)
# for i in range(1, len(gpu_dist)): # for i in range(1, len(gpu_dist)):
# assigned_gpus = assigned_gpus + [i] * gpu_dist[i] # assigned_gpus = assigned_gpus + [i] * gpu_dist[i]
# #
# remaining_assignments = len(layers)-len(assigned_gpus) - 1 # remaining_assignments = len(layers)-len(assigned_gpus) - 1
# if remaining_assignments > 0: # if remaining_assignments > 0:
# assigned_gpus = assigned_gpus + [-1] * remaining_assignments # assigned_gpus = assigned_gpus + [-1] * remaining_assignments
# #
# assigned_gpus = assigned_gpus + [0] # assigned_gpus = assigned_gpus + [0]
# #
# for i in range(len(layers)): # for i in range(len(layers)):
# layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0) # layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0)
# #
# model.gpus = gpus # model.gpus = gpus
# #
# #
# def benchmark(model, input_ids, check=False): # def benchmark(model, input_ids, check=False):
# input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV) # input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
# torch.cuda.synchronize() # torch.cuda.synchronize()
# #
# cache = {'past': None} # cache = {'past': None}
# #
# def clear_past(i): # def clear_past(i):
# #
# def tmp(layer, inp, out): # def tmp(layer, inp, out):
# if cache['past']: # if cache['past']:
# cache['past'][i] = None # cache['past'][i] = None
# #
# return tmp # return tmp
# #
# for i, layer in enumerate(model.model.layers): # for i, layer in enumerate(model.model.layers):
# layer.register_forward_hook(clear_past(i)) # layer.register_forward_hook(clear_past(i))
# #
# print('Benchmarking ...') # print('Benchmarking ...')
# #
# if check: # if check:
# loss = nn.CrossEntropyLoss() # loss = nn.CrossEntropyLoss()
# tot = 0. # tot = 0.
# #
# def sync(): # def sync():
# if hasattr(model, 'gpus'): # if hasattr(model, 'gpus'):
# for gpu in model.gpus: # for gpu in model.gpus:
# torch.cuda.synchronize(gpu) # torch.cuda.synchronize(gpu)
# else: # else:
# torch.cuda.synchronize() # torch.cuda.synchronize()
# #
# max_memory = 0 # max_memory = 0
# with torch.no_grad(): # with torch.no_grad():
# attention_mask = torch.ones((1, input_ids.numel()), device=DEV) # attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
@ -921,9 +1044,11 @@ def pack(model, quantizers, wbits, groupsize):
# print('max memory(MiB):', max_memory) # print('max memory(MiB):', max_memory)
def quantize(model_id: str, wbits: int, groupsize: int): def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
print("loading model") print("loading model")
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0") model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, device_map="balanced_low_0"
)
print("LOADED model") print("LOADED model")
model.seqlen = 2048 model.seqlen = 2048
@ -931,11 +1056,12 @@ def quantize(model_id: str, wbits: int, groupsize: int):
nsamples = 128 nsamples = 128
seed = None seed = None
dataloader, testloader = get_loaders(
dataloader, testloader = get_loaders(dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen) dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen
)
tick = time.time() tick = time.time()
quantizers = sequential(model, dataloader, DEV, nsamples, wbits, groupsize) quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
print(time.time() - tick) print(time.time() - tick)
# if args.benchmark: # if args.benchmark:
@ -956,7 +1082,7 @@ def quantize(model_id: str, wbits: int, groupsize: int):
# dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) # dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
# print(dataset) # print(dataset)
# llama_eval(model, testloader, DEV) # llama_eval(model, testloader, DEV)
# #
# if args.test_generation: # if args.test_generation:
# gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] # gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
# if len(gpus) > 1: # if len(gpus) > 1:
@ -970,20 +1096,57 @@ def quantize(model_id: str, wbits: int, groupsize: int):
# streamer = TextStreamer(tokenizer) # streamer = TextStreamer(tokenizer)
# with torch.no_grad(): # with torch.no_grad():
# generated_ids = model.generate(input_ids, streamer=streamer) # generated_ids = model.generate(input_ids, streamer=streamer)
# #
# if args.quant_directory is not None: # if args.quant_directory is not None:
# export_quant_table(quantizers, args.quant_directory) # export_quant_table(quantizers, args.quant_directory)
# if not args.observe and args.save: # if not args.observe and args.save:
# llama_pack(model, quantizers, args.wbits, args.groupsize) # llama_pack(model, quantizers, args.bits, args.groupsize)
# torch.save(model.state_dict(), args.save) # torch.save(model.state_dict(), args.save)
# if not args.observe and args.save_safetensors: # if not args.observe and args.save_safetensors:
pack(model, quantizers, wbits, groupsize) pack(model, quantizers, bits, groupsize)
from safetensors.torch import save_file as safe_save from safetensors.torch import save_file
state_dict = model.state_dict() from transformers.modeling_utils import shard_checkpoint
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
safe_save(state_dict, args.save_safetensors)
state_dict = model.state_dict()
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
state_dict["gptq_bits"] = torch.LongTensor(bits)
state_dict["gptq_groupsize"] = torch.LongTensor(groupsize)
shards, index = shard_checkpoint(
state_dict, max_shard_size="10GB", weights_name="model.safetensors"
)
os.makedirs(output_dir, exist_ok=True)
for shard_file, shard in shards.items():
save_file(
shard,
os.path.join(output_dir, shard_file),
metadata={
"format": "pt",
"quantized": "gptq",
"origin": "text-generation-inference",
},
)
if index is None:
path_to_weights = os.path.join(save_directory, "model.safetensors")
logger.info(f"Model weights saved in {path_to_weights}")
else:
save_index_file = "model.safetensors.index.json"
save_index_file = os.path.join(save_directory, save_index_file)
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
config = AutoConfig.from_pretrained(model_id)
config.save_pretrained(output_dir)
logger.info("Saved config")
logger.info("Saving tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(output_dir)
logger.info("Saved tokenizer")

View File

@ -134,13 +134,15 @@ def get_linear(weight, bias, quantize):
try: try:
qweight, qzeros, scales, g_idx, bits, groupsize = weight qweight, qzeros, scales, g_idx, bits, groupsize = weight
except Exception: except Exception:
raise NotImplementedError(f"The passed weight is not `gptq` compatible, loader needs to be updated.") raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
linear = QuantLinear( linear = QuantLinear(
qweight, qweight,
qzeros, qzeros,
scales, scales,
g_idx, g_idx,
bias, bias,
bits, bits,
groupsize, groupsize,

View File

@ -86,19 +86,27 @@ class Weights:
def get_multi_weights_col(self, prefixes: List[str], quantize: str): def get_multi_weights_col(self, prefixes: List[str], quantize: str):
if quantize == "gptq": if quantize == "gptq":
try: try:
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1) qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError: except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1) qzeros = torch.cat(
scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1) [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]: for w2 in w[1:]:
torch.testing.assert_close(w2, w[0]) torch.testing.assert_close(w2, w[0])
g_idx = w[0] g_idx = w[0]
# TODO Get that from file to be more generic
bits = 4 bits = self.get_tensor("gptq_bits").item()
groupsize = 128 groupsize = self.get_tensor("gptq_groupsize").item()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize) weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
else: else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -110,14 +118,15 @@ class Weights:
try: try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0) qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError: except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = self.get_tensor(f"{prefix}.qzeros") qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales") scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
# TODO Get that from file to be more generic bits = self.get_tensor("gptq_bits").item()
bits = 4 groupsize = self.get_tensor("gptq_groupsize").item()
groupsize = 128
weight = (qweight, qzeros, scales, g_idx, bits, groupsize) weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
else: else: