diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 69d6417b..b4539c46 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -150,14 +150,16 @@ def download_weights( # Convert pytorch weights to safetensors utils.convert_files(local_pt_files, local_st_files) + @app.command() def quantize( model_id: str, + output_dir: str, revision: Optional[str] = None, logger_level: str = "INFO", json_output: bool = False, ): - extension: str = ".safetensors", + extension: str = (".safetensors",) # Remove default handler logger.remove() logger.add( @@ -169,12 +171,15 @@ def quantize( backtrace=True, diagnose=False, ) - download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output) + download_weights( + model_id=model_id, + revision=revision, + logger_level=logger_level, + json_output=json_output, + ) from text_generation_server.utils.gptq.quantize import quantize - quantize(model_id=model_id, wbits=4, groupsize=128) - - + quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir) if __name__ == "__main__": diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 65aa2a4b..edf8d8ad 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -248,7 +248,9 @@ def get_model( if sharded: raise ValueError("sharded is not supported for AutoModel") if quantize == "gptq": - raise ValueError("gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise ValueError( + "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 51348791..63dbedb7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -59,7 +59,7 @@ def load_row(config, prefix: str, weights, bias: bool): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): weight = weights.get_multi_weights_col([prefix], quantize=config.quantize) - if isinstance(weight, torch.Tensor): + if isinstance(weight, torch.Tensor): # Only on non quantized versions weight = ( weight.view( @@ -75,7 +75,6 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1) - linear = get_linear(weight, bias, config.quantize) if config.use_parallel_residual: return linear diff --git a/server/text_generation_server/utils/gptq/custom_autotune.py b/server/text_generation_server/utils/gptq/custom_autotune.py index 875c832e..17dff02e 100644 --- a/server/text_generation_server/utils/gptq/custom_autotune.py +++ b/server/text_generation_server/utils/gptq/custom_autotune.py @@ -1,4 +1,4 @@ -#https://github.com/fpgaminer/GPTQ-triton +# https://github.com/fpgaminer/GPTQ-triton """ Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. 
""" @@ -12,15 +12,23 @@ import triton class Autotuner(triton.KernelInterface): - - def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False): - ''' - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. - 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results - ''' + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + prune_configs_by: Dict = None, + nearest_power_of_two: bool = False, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. + 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results + """ if not configs: self.configs = [triton.Config({}, num_warps=4, num_stages=2)] else: @@ -41,9 +49,12 @@ class Autotuner(triton.KernelInterface): self.arg_names = arg_names # prune configs if prune_configs_by: - perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] - if 'early_config_prune' in prune_configs_by: - early_config_prune = prune_configs_by['early_config_prune'] + perf_model, top_k = ( + prune_configs_by["perf_model"], + prune_configs_by["top_k"], + ) + if "early_config_prune" in prune_configs_by: + early_config_prune = prune_configs_by["early_config_prune"] else: perf_model, top_k, early_config_prune = None, None, None self.perf_model, self.configs_top_k = perf_model, top_k @@ -55,8 +66,10 @@ class Autotuner(triton.KernelInterface): # as kwargs and by the autotuner conflicts = meta.keys() & config.kwargs.keys() if conflicts: - raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols.") + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." 
+ ) # augment meta-parameters with tunable ones current = dict(meta, **config.kwargs) @@ -64,14 +77,21 @@ class Autotuner(triton.KernelInterface): if config.pre_hook: config.pre_hook(self.nargs) self.hook(args) - self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **current, + ) try: # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default - return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40) + return triton.testing.do_bench( + kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40 + ) except triton.compiler.OutOfResources: - return (float('inf'), float('inf'), float('inf')) + return (float("inf"), float("inf"), float("inf")) def run(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) @@ -81,13 +101,16 @@ class Autotuner(triton.KernelInterface): # This reduces the amount of autotuning by rounding the keys to the nearest power of two # In my testing this gives decent results, and greatly reduces the amount of tuning required if self.nearest_power_of_two: - key = tuple([2**int(math.log2(x) + 0.5) for x in key]) + key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) if key not in self.cache: # prune configs pruned_configs = self.prune_configs(kwargs) bench_start = time.time() - timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + timings = { + config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs + } bench_end = time.time() self.bench_time = bench_end - bench_start self.cache[key] = builtins.min(timings, key=timings.get) @@ -99,7 +122,13 @@ class Autotuner(triton.KernelInterface): self.best_config = config if config.pre_hook is not None: config.pre_hook(self.nargs) - return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + return self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) def prune_configs(self, kwargs): pruned_configs = self.configs @@ -110,8 +139,19 @@ class Autotuner(triton.KernelInterface): if isinstance(top_k, float) and top_k <= 1.0: top_k = int(len(self.configs) * top_k) if len(pruned_configs) > top_k: - est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ + :top_k + ] return pruned_configs def warmup(self, *args, **kwargs): @@ -127,39 +167,49 @@ class Autotuner(triton.KernelInterface): self.nargs = None -def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): +def autotune( + configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False +): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + .. highlight:: python + .. 
code-block:: python + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple time. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + reset the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] """ - Decorator for auto-tuning a :code:`triton.jit`'d function. - .. highlight:: python - .. code-block:: python - @triton.autotune(configs=[ - triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), - triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), - ], - key=['x_size'] # the two above configs will be evaluated anytime - # the value of x_size changes - ) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] - :note: When all the configurations are evaluated, the kernel will run multiple time. - This means that whatever value the kernel updates will be updated multiple times. - To avoid this undesired behavior, you can use the `reset_to_zero` argument, which - reset the value of the provided tensor to `zero` before running any configuration. - :param configs: a list of :code:`triton.Config` objects - :type configs: list[triton.Config] - :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. - :type key: list[str] - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. - :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. - :type reset_to_zero: list[str] - """ def decorator(fn): - return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two) + return Autotuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + prune_configs_by, + nearest_power_of_two, + ) return decorator @@ -168,26 +218,44 @@ def matmul248_kernel_config_pruner(configs, nargs): """ The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. 
""" - m = max(2**int(math.ceil(math.log2(nargs['M']))), 16) - n = max(2**int(math.ceil(math.log2(nargs['N']))), 16) - k = max(2**int(math.ceil(math.log2(nargs['K']))), 16) + m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) + n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) + k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) used = set() for config in configs: - block_size_m = min(m, config.kwargs['BLOCK_SIZE_M']) - block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) - block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) - group_size_m = config.kwargs['GROUP_SIZE_M'] + block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) + block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) + block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) + group_size_m = config.kwargs["GROUP_SIZE_M"] - if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used: + if ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) in used: continue - used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps)) - yield triton.Config({ - 'BLOCK_SIZE_M': block_size_m, - 'BLOCK_SIZE_N': block_size_n, - 'BLOCK_SIZE_K': block_size_k, - 'GROUP_SIZE_M': group_size_m - }, - num_stages=config.num_stages, - num_warps=config.num_warps) + used.add( + ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) + ) + yield triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + }, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py index 6e939115..cd3c4d35 100644 --- a/server/text_generation_server/utils/gptq/quant_linear.py +++ b/server/text_generation_server/utils/gptq/quant_linear.py @@ -12,66 +12,121 @@ try: # code based https://github.com/fpgaminer/GPTQ-triton @custom_autotune.autotune( configs=[ - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 8 - }, num_stages=3, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=4), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + 
{ + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=4, + ), ], - key=['M', 'N', 'K'], + key=["M", "N", "K"], nearest_power_of_two=True, prune_configs_by={ - 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, - 'perf_model': None, - 'top_k': None, + "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, }, ) @triton.jit - def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + def matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + g_ptr, + M, + N, + K, + bits, + maxq, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ): """ Compute the matrix multiplication C = A x B. 
A is of shape (M, K) float16 @@ -79,7 +134,7 @@ try: C is of shape (M, N) float16 scales is of shape (G, N) float16 zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 + g_ptr is of shape (K) int32 """ infearure_per_bits = 32 // bits @@ -97,10 +152,15 @@ try: offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = offs_am[:, None] < M # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + b_ptrs = b_ptr + ( + (offs_k[:, None] // infearure_per_bits) * stride_bk + + offs_bn[None, :] * stride_bn + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) g_ptrs = g_ptr + offs_k # shifter is used to extract the N bits of each element in the 32-bit word from B scales_ptrs = scales_ptr + offs_bn[None, :] @@ -114,13 +174,17 @@ try: g_idx = tl.load(g_ptrs) # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + scales = tl.load( + scales_ptrs + g_idx[:, None] * stride_scales + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load( + zeros_ptrs + g_idx[:, None] * stride_zeros + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) + zeros = zeros + 1 - a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated # Now we need to unpack b (which is N-bit values) into 32-bit values @@ -136,61 +200,118 @@ try: c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) - @custom_autotune.autotune(configs=[ - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 256, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 8 - }, num_stages=3, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=4), - ], - key=['M', 'N', 'K'], - nearest_power_of_two=True) + @custom_autotune.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=4, + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + ) @triton.jit - def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, - stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + def transpose_matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + g_ptr, + M, + N, + K, + bits, + maxq, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: 
tl.constexpr, + ): """ Compute the matrix multiplication C = A x B. A is of shape (M, N) float16 @@ -198,7 +319,7 @@ try: C is of shape (M, K) float16 scales is of shape (G, N) float16 zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 + g_ptr is of shape (K) int32 """ infearure_per_bits = 32 // bits @@ -216,16 +337,25 @@ try: offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) offs_n = tl.arange(0, BLOCK_SIZE_N) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) - a_mask = (offs_am[:, None] < M) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + a_mask = offs_am[:, None] < M # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + b_ptrs = b_ptr + ( + (offs_bk[:, None] // infearure_per_bits) * stride_bk + + offs_n[None, :] * stride_bn + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) g_ptrs = g_ptr + offs_bk g_idx = tl.load(g_ptrs) # shifter is used to extract the N bits of each element in the 32-bit word from B scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales - zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros + zeros_ptrs = ( + zeros_ptr + + (offs_n[None, :] // infearure_per_bits) + + g_idx[:, None] * stride_zeros + ) shifter = (offs_bk % infearure_per_bits) * bits zeros_shifter = (offs_n % infearure_per_bits) * bits @@ -237,9 +367,9 @@ try: zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) + zeros = zeros + 1 - a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_N) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated # Now we need to unpack b (which is N-bit values) into 32-bit values @@ -251,36 +381,84 @@ try: a_ptrs += BLOCK_SIZE_N b_ptrs += BLOCK_SIZE_N scales_ptrs += BLOCK_SIZE_N - zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) + zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) tl.store(c_ptrs, accumulator, mask=c_mask) + except: - print('triton not installed.') + print("triton not installed.") def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): with torch.cuda.device(input.device): - output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) - matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), - qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) + output = torch.empty( + (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 + ) + grid = lambda META: ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + g_idx, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + ) return output def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): with torch.cuda.device(input.device): output_dim = (qweight.shape[0] * 32) // bits - output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16) - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) - transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), - qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) + output = torch.empty( + (input.shape[0], output_dim), device=input.device, dtype=torch.float16 + ) + grid = lambda META: ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(output_dim, META["BLOCK_SIZE_K"]), + ) + transpose_matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + g_idx, + input.shape[0], + qweight.shape[1], + output_dim, + bits, + maxq, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + ) return output class QuantLinearFunction(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): @@ -297,7 +475,9 @@ class QuantLinearFunction(torch.autograd.Function): grad_input = None if ctx.needs_input_grad[0]: - grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq) + grad_input = 
transpose_matmul248( + grad_output, qweight, scales, qzeros, g_idx, bits, maxq + ) return grad_input, None, None, None, None, None, None @@ -318,8 +498,41 @@ class QuantLinear(nn.Module): self.outfeatures = qweight.shape[1] self.infeatures = qweight.shape[0] * 32 // 4 + @classmethod + def new(cls, bits, groupsize, infeatures, outfeatures, bias): + super().__init__() + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = torch.zeros( + (infeatures // 32 * self.bits, outfeatures), dtype=torch.int32 + ) + qzeros = torch.zeros( + (math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), + dtype=torch.int32, + ) + scales = torch.zeros( + (math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16 + ) + g_idx = torch.tensor( + [i // self.groupsize for i in range(infeatures)], dtype=torch.int32 + ) + if bias: + bias = torch.zeros((outfeatures), dtype=torch.float16) + else: + bias = None + return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures, ) - out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq) + out_shape = x.shape[:-1] + (self.outfeatures,) + out = QuantLinearFunction.apply( + x.reshape(-1, x.shape[-1]), + self.qweight, + self.scales, + self.qzeros, + self.g_idx, + self.bits, + self.maxq, + ) out = out + self.bias if self.bias is not None else out return out.reshape(out_shape) diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py index f7436f88..a86d518e 100644 --- a/server/text_generation_server/utils/gptq/quantize.py +++ b/server/text_generation_server/utils/gptq/quantize.py @@ -4,26 +4,36 @@ import numpy as np import torch import torch.nn as nn import math +import os from texttable import Texttable from transformers import AutoModelForCausalLM import transformers import numpy as np import torch +from text_generation_server.utils.gptq.quant_linear import QuantLinear DEV = torch.device("cuda:0") class Quantizer(nn.Module): - def __init__(self, shape=1): super(Quantizer, self).__init__() - self.register_buffer('maxq', torch.tensor(0)) - self.register_buffer('scale', torch.zeros(shape)) - self.register_buffer('zero', torch.zeros(shape)) - - def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False): + self.register_buffer("maxq", torch.tensor(0)) + self.register_buffer("scale", torch.zeros(shape)) + self.register_buffer("zero", torch.zeros(shape)) + def configure( + self, + bits, + perchannel=False, + sym=True, + mse=False, + norm=2.4, + grid=100, + maxshrink=0.8, + trits=False, + ): self.maxq = torch.tensor(2**bits - 1) self.perchannel = perchannel self.sym = sym @@ -84,14 +94,16 @@ class Quantizer(nn.Module): self.zero = torch.round(-xmin / self.scale) if self.mse: - best = torch.full([x.shape[0]], float('inf'), device=dev) + best = torch.full([x.shape[0]], float("inf"), device=dev) for i in range(int(self.maxshrink * self.grid)): p = 1 - i / self.grid xmin1 = p * xmin xmax1 = p * xmax scale1 = (xmax1 - xmin1) / self.maxq zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q = self._quantize( + x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq + ) q -= x q.abs_() q.pow_(self.norm) @@ -138,7 +150,6 @@ class Quantizer(nn.Module): class 
GPTQ: - def __init__(self, layer, observe=False): self.layer = layer self.dev = self.layer.weight.device @@ -166,12 +177,19 @@ class GPTQ: if len(inp.shape) == 2: inp = inp.unsqueeze(0) tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): + if isinstance(self.layer, nn.Linear) or isinstance( + self.layer, transformers.Conv1D + ): if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() if isinstance(self.layer, nn.Conv2d): - unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride) + unfold = nn.Unfold( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride, + ) inp = unfold(inp) inp = inp.permute([1, 0, 2]) inp = inp.flatten(1) @@ -184,12 +202,14 @@ class GPTQ: def print_loss(self, name, q_weight, weight_error, timecost): table = Texttable() - name += ' ' * (16 - len(name)) + name += " " * (16 - len(name)) - table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time']) + table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"]) # assign weight - self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to( + self.layer.weight.data.dtype + ) if self.inp1 is not None: # quantize input to int8 @@ -203,13 +223,15 @@ class GPTQ: q_SNR = torch_snr_error(q_out, self.out1).item() fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item() else: - q_SNR = '-' - fp_SNR = '-' + q_SNR = "-" + fp_SNR = "-" table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) - print(table.draw().split('\n')[-2]) + print(table.draw().split("\n")[-2]) - def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''): + def fasterquant( + self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name="" + ): self.layer.to(self.dev) W = self.layer.weight.data.clone() @@ -268,7 +290,9 @@ class GPTQ: if groupsize != -1: if (i1 + i) % groupsize == 0: - self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + self.quantizer.find_params( + W[:, (i1 + i) : (i1 + i + groupsize)], weight=True + ) if ((i1 + i) // groupsize) - now_idx == -1: scale.append(self.quantizer.scale) @@ -277,7 +301,7 @@ class GPTQ: q = self.quantizer.quantize(w.unsqueeze(1)).flatten() Q1[:, i] = q - Losses1[:, i] = (w - q)**2 / d**2 + Losses1[:, i] = (w - q) ** 2 / d**2 err1 = (w - q) / d W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) @@ -302,7 +326,9 @@ class GPTQ: if isinstance(self.layer, transformers.Conv1D): Q = Q.t() - self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)) + self.print_loss( + name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick) + ) if scale == []: scale.append(self.quantizer.scale) @@ -322,15 +348,18 @@ class GPTQ: def get_wikitext2(nsamples, seed, seqlen, model_id): from datasets import load_dataset - traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') - testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) - trainenc = tokenizer("\n\n".join(traindata['text']), 
return_tensors='pt') - testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') + trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") import random + random.seed(seed) trainloader = [] for _ in range(nsamples): @@ -345,18 +374,21 @@ def get_wikitext2(nsamples, seed, seqlen, model_id): def get_ptb(nsamples, seed, seqlen, model_id): from datasets import load_dataset - traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') - valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') + + traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") + valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation") from transformers import AutoTokenizer + try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) - trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') - testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') + trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt") import random + random.seed(seed) trainloader = [] for _ in range(nsamples): @@ -371,22 +403,37 @@ def get_ptb(nsamples, seed, seqlen, model_id): def get_c4(nsamples, seed, seqlen, model_id): from datasets import load_dataset - traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False) - valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False) + + traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + use_auth_token=False, + ) + valdata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + use_auth_token=False, + ) from transformers import AutoTokenizer + try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) import random + random.seed(seed) trainloader = [] for _ in range(nsamples): while True: i = random.randint(0, len(traindata) - 1) - trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") if trainenc.input_ids.shape[1] >= seqlen: break i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) @@ -397,12 +444,13 @@ def get_c4(nsamples, seed, seqlen, model_id): trainloader.append((inp, tar)) import random + random.seed(0) valenc = [] for _ in range(256): while True: i = random.randint(0, len(valdata) - 1) - tmp = tokenizer(valdata[i]['text'], return_tensors='pt') + tmp = tokenizer(valdata[i]["text"], return_tensors="pt") if tmp.input_ids.shape[1] >= seqlen: break i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) @@ -411,7 +459,6 @@ def get_c4(nsamples, seed, seqlen, model_id): valenc = torch.hstack(valenc) class TokenizerWrapper: - def __init__(self, input_ids): self.input_ids = input_ids @@ -422,18 +469,21 @@ def get_c4(nsamples, seed, seqlen, model_id): def get_ptb_new(nsamples, seed, seqlen, model_id): from datasets import load_dataset - traindata = 
load_dataset('ptb_text_only', 'penn_treebank', split='train') - testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') + + traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") + testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") from transformers import AutoTokenizer + try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) - trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') - testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') + trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt") + testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt") import random + random.seed(seed) trainloader = [] for _ in range(nsamples): @@ -448,22 +498,35 @@ def get_ptb_new(nsamples, seed, seqlen, model_id): def get_c4_new(nsamples, seed, seqlen, model_id): from datasets import load_dataset - traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') - valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') + + traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + ) + valdata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + ) from transformers import AutoTokenizer + try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) import random + random.seed(seed) trainloader = [] for _ in range(nsamples): while True: i = random.randint(0, len(traindata) - 1) - trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") if trainenc.input_ids.shape[1] >= seqlen: break i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) @@ -473,11 +536,10 @@ def get_c4_new(nsamples, seed, seqlen, model_id): tar[:, :-1] = -100 trainloader.append((inp, tar)) - valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') - valenc = valenc.input_ids[:, :(256 * seqlen)] + valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt") + valenc = valenc.input_ids[:, : (256 * seqlen)] class TokenizerWrapper: - def __init__(self, input_ids): self.input_ids = input_ids @@ -486,31 +548,46 @@ def get_c4_new(nsamples, seed, seqlen, model_id): return trainloader, valenc -def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=''): - if 'wikitext2' in name: +def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""): + if "wikitext2" in name: return get_wikitext2(nsamples, seed, seqlen, model_id) - if 'ptb' in name: - if 'new' in name: + if "ptb" in name: + if "new" in name: return get_ptb_new(nsamples, seed, seqlen, model_id) return get_ptb(nsamples, seed, seqlen, model_id) - if 'c4' in name: - if 'new' in name: + if "c4" in name: + if "new" in name: return get_c4_new(nsamples, seed, seqlen, model_id) return get_c4(nsamples, seed, seqlen, model_id) -def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""): # Skip last lm_head linear if type(module) in layers and "lm_head" not in name: return {name: 
module} res = {} for name1, child in module.named_children(): - res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) + res.update( + find_layers( + child, layers=layers, name=name + "." + name1 if name != "" else name1 + ) + ) return res + @torch.no_grad() -def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False): - print('Starting ...') +def sequential( + model, + dataloader, + dev, + nsamples, + bits, + groupsize, + percdamp=0.01, + sym: bool = False, + act_order: bool = False, +): + print("Starting ...") use_cache = model.config.use_cache model.config.use_cache = False @@ -524,20 +601,21 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01 # layers[0] = layers[0].to(dev) dtype = next(iter(model.parameters())).dtype - inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) - cache = {'i': 0, 'attention_mask': None} + inps = torch.zeros( + (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev + ) + cache = {"i": 0, "attention_mask": None} class Catcher(nn.Module): - def __init__(self, module): super().__init__() self.module = module def forward(self, inp, **kwargs): - inps[cache['i']] = inp - cache['i'] += 1 - cache['attention_mask'] = kwargs['attention_mask'] - cache['position_ids'] = kwargs['position_ids'] + inps[cache["i"]] = inp + cache["i"] += 1 + cache["attention_mask"] = kwargs["attention_mask"] + cache["position_ids"] = kwargs["position_ids"] raise ValueError layers[0] = Catcher(layers[0]) @@ -554,20 +632,20 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01 torch.cuda.empty_cache() outs = torch.zeros_like(inps) - attention_mask = cache['attention_mask'].to(dev) - position_ids = cache['position_ids'].to(dev) + attention_mask = cache["attention_mask"].to(dev) + position_ids = cache["position_ids"].to(dev) - print('Ready.') + print("Ready.") quantizers = {} for i in range(len(layers)): - - print(f'Quantizing layer {i+1}/{len(layers)}..') - print('+------------------+--------------+------------+-----------+-------+') - print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |') - print('+==================+==============+============+===========+=======+') + print(f"Quantizing layer {i+1}/{len(layers)}..") + print("+------------------+--------------+------------+-----------+-------+") + print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |") + print("+==================+==============+============+===========+=======+") from accelerate.hooks import remove_hook_from_submodules + layer = layers[i].to(dev) remove_hook_from_submodules(layer) full = find_layers(layer) @@ -578,10 +656,11 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01 gptq = {} for name in subset: gptq[name] = GPTQ(subset[name]) - gptq[name].quantizer.configure(wbits, perchannel=True, sym=sym, mse=False) + gptq[name].quantizer.configure( + bits, perchannel=True, sym=sym, mse=False + ) def add_batch(name): - def tmp(_, inp, out): gptq[name].add_batch(inp[0].data, out.data) @@ -591,19 +670,38 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01 for name in subset: handles.append(subset[name].register_forward_hook(add_batch(name))) for j in range(nsamples): - - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] + outs[j] = layer( + inps[j].unsqueeze(0), + 
attention_mask=attention_mask, + position_ids=position_ids, + )[0] for h in handles: h.remove() for name in subset: - scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=percdamp, groupsize=groupsize, actorder=act_order, name=name) - quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize) + scale, zero, g_idx, error = gptq[name].fasterquant( + percdamp=percdamp, + groupsize=groupsize, + actorder=act_order, + name=name, + ) + quantizers["model.layers.%d.%s" % (i, name)] = ( + gptq[name].quantizer.cpu(), + scale.cpu(), + zero.cpu(), + g_idx.cpu(), + bits, + groupsize, + ) gptq[name].free() for j in range(nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] + outs[j] = layer( + inps[j].unsqueeze(0), + attention_mask=attention_mask, + position_ids=position_ids, + )[0] layers[i] = layer.cpu() del layer @@ -611,12 +709,12 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01 torch.cuda.empty_cache() inps, outs = outs, inps - print('+------------------+--------------+------------+-----------+-------+') - print('\n') + print("+------------------+--------------+------------+-----------+-------+") + print("\n") # if args.observe: # observer.print() - # conditions = gen_conditions(args.wbits, args.groupsize) + # conditions = gen_conditions(args.bits, args.groupsize) # for item in observer.items(): # name = item[0] # layerid = item[1] @@ -625,23 +723,23 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01 # target = error / 2 # table = Texttable() - # table.header(['wbits', 'groupsize', 'error']) + # table.header(['bits', 'groupsize', 'error']) # table.set_cols_dtype(['i', 'i', 'f']) - # table.add_row([args.wbits, args.groupsize, error]) + # table.add_row([args.bits, args.groupsize, error]) # print('Optimizing {} {} ..'.format(name, layerid)) - # for wbits, groupsize in conditions: + # for bits, groupsize in conditions: # if error < target: # # if error dropped 50%, skip # break - # gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False) + # gptq.quantizer.configure(bits, perchannel=True, sym=args.sym, mse=False) # scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name) - # table.add_row([wbits, groupsize, error]) - # quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize) + # table.add_row([bits, groupsize, error]) + # quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), bits, groupsize) # print(table.draw()) # print('\n') @@ -656,34 +754,34 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01 # @torch.no_grad() # def llama_eval(model, testenc, dev): # print('Evaluating ...') -# +# # testenc = testenc.input_ids # nsamples = testenc.numel() // model.seqlen -# +# # use_cache = model.config.use_cache # model.config.use_cache = False # layers = model.model.layers -# +# # model.model.embed_tokens = model.model.embed_tokens.to(dev) # layers[0] = layers[0].to(dev) -# +# # dtype = next(iter(model.parameters())).dtype # inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) # cache = {'i': 0, 'attention_mask': None} -# +# # class Catcher(nn.Module): -# +# # def __init__(self, module): # super().__init__() # 
self.module = module -# +# # def forward(self, inp, **kwargs): # inps[cache['i']] = inp # cache['i'] += 1 # cache['attention_mask'] = kwargs['attention_mask'] # cache['position_ids'] = kwargs['position_ids'] # raise ValueError -# +# # layers[0] = Catcher(layers[0]) # for i in range(nsamples): # batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) @@ -692,39 +790,39 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01 # except ValueError: # pass # layers[0] = layers[0].module -# +# # layers[0] = layers[0].cpu() # model.model.embed_tokens = model.model.embed_tokens.cpu() # torch.cuda.empty_cache() -# +# # outs = torch.zeros_like(inps) # attention_mask = cache['attention_mask'] # position_ids = cache['position_ids'] -# +# # for i in range(len(layers)): # print(i) # layer = layers[i].to(dev) -# +# # if args.nearest: # subset = find_layers(layer) # for name in subset: # quantizer = quant.Quantizer() -# quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) +# quantizer.configure(args.bits, perchannel=True, sym=args.sym, mse=False) # W = subset[name].weight.data # quantizer.find_params(W, weight=True) # subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype) -# +# # for j in range(nsamples): # outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] # layers[i] = layer.cpu() # del layer # torch.cuda.empty_cache() # inps, outs = outs, inps -# +# # if model.model.norm is not None: # model.model.norm = model.model.norm.to(dev) # model.lm_head = model.lm_head.to(dev) -# +# # testenc = testenc.to(dev) # nlls = [] # for i in range(nsamples): @@ -740,36 +838,61 @@ def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01 # nlls.append(neg_log_likelihood) # ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) # print(ppl.item()) -# +# # model.config.use_cache = use_cache +def make_quant_linear(module, names, bits, groupsize, name=""): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + "." + attr if name != "" else attr + if name1 in names: + delattr(module, attr) + setattr( + module, + attr, + QuantLinear.new( + bits, + groupsize, + tmp.in_features, + tmp.out_features, + tmp.bias is not None, + ), + ) + for name1, child in module.named_children(): + make_quant_linear( + child, names, bits, groupsize, name + "." 
+ name1 if name != "" else name1 + ) + + # TODO: perform packing on GPU -def pack(model, quantizers, wbits, groupsize): +def pack(model, quantizers, bits, groupsize): layers = find_layers(model) layers = {n: layers[n] for n in quantizers} - quant.make_quant_linear(model, quantizers, wbits, groupsize) + make_quant_linear(model, quantizers, bits, groupsize) qlayers = find_layers(model, [QuantLinear]) - print('Packing ...') + print("Packing ...") for name in qlayers: print(name) quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] qlayers[name].pack(layers[name], scale, zero, g_idx) - print('Done.') + print("Done.") return model -# def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): +# def load_quant(model, checkpoint, bits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): # from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils # config = LlamaConfig.from_pretrained(model) -# +# # def noop(*args, **kwargs): # pass -# +# # torch.nn.init.kaiming_uniform_ = noop # torch.nn.init.uniform_ = noop # torch.nn.init.normal_ = noop -# +# # torch.set_default_dtype(torch.half) # modeling_utils._init_weights = False # torch.set_default_dtype(torch.half) @@ -781,30 +904,30 @@ def pack(model, quantizers, wbits, groupsize): # for name in ['lm_head']: # if name in layers: # del layers[name] -# quant.make_quant_linear(model, layers, wbits, groupsize) -# +# quant.make_quant_linear(model, layers, bits, groupsize) +# # del layers -# +# # print('Loading model ...') # if checkpoint.endswith('.safetensors'): # from safetensors.torch import load_file as safe_load # model.load_state_dict(safe_load(checkpoint)) # else: # model.load_state_dict(torch.load(checkpoint)) -# +# # if eval: # quant.make_quant_attn(model) # quant.make_quant_norm(model) # if fused_mlp: # quant.make_fused_mlp(model) -# +# # if warmup_autotune: # quant.autotune_warmup_linear(model, transpose=not (eval)) # if eval and fused_mlp: # quant.autotune_warmup_fused(model) # model.seqlen = 2048 # print('Done.') -# +# # return model @@ -814,33 +937,33 @@ def pack(model, quantizers, wbits, groupsize): # model.model.norm = model.model.norm.to(gpus[0]) # import copy # model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0]) -# +# # cache = {'mask': None, 'position_ids': None} -# +# # class MoveModule(nn.Module): -# +# # def __init__(self, module, invalidate_cache): # super().__init__() # self.module = module # self.dev = next(iter(self.module.parameters())).device # self.invalidate_cache=invalidate_cache -# +# # def forward(self, *inp, **kwargs): # inp = list(inp) # if inp[0].device != self.dev: # inp[0] = inp[0].to(self.dev) -# +# # if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache: # cache['mask'] = kwargs['attention_mask'].to(self.dev) # kwargs['attention_mask'] = cache['mask'] -# +# # if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache: # cache['position_ids'] = kwargs['position_ids'].to(self.dev) # kwargs['position_ids'] = cache['position_ids'] -# +# # tmp = self.module(*inp, **kwargs) # return tmp -# +# # layers = model.model.layers # from math import ceil # if not gpu_dist: @@ -852,49 +975,49 @@ def pack(model, quantizers, wbits, groupsize): # assigned_gpus = [0] * (gpu_dist[0]-1) # for i in range(1, len(gpu_dist)): # assigned_gpus = assigned_gpus + [i] * gpu_dist[i] -# +# # remaining_assignments = len(layers)-len(assigned_gpus) - 1 # if remaining_assignments > 0: # 
assigned_gpus = assigned_gpus + [-1] * remaining_assignments -# +# # assigned_gpus = assigned_gpus + [0] -# +# # for i in range(len(layers)): # layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0) -# +# # model.gpus = gpus -# -# +# +# # def benchmark(model, input_ids, check=False): # input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV) # torch.cuda.synchronize() -# +# # cache = {'past': None} -# +# # def clear_past(i): -# +# # def tmp(layer, inp, out): # if cache['past']: # cache['past'][i] = None -# +# # return tmp -# +# # for i, layer in enumerate(model.model.layers): # layer.register_forward_hook(clear_past(i)) -# +# # print('Benchmarking ...') -# +# # if check: # loss = nn.CrossEntropyLoss() # tot = 0. -# +# # def sync(): # if hasattr(model, 'gpus'): # for gpu in model.gpus: # torch.cuda.synchronize(gpu) # else: # torch.cuda.synchronize() -# +# # max_memory = 0 # with torch.no_grad(): # attention_mask = torch.ones((1, input_ids.numel()), device=DEV) @@ -921,9 +1044,11 @@ def pack(model, quantizers, wbits, groupsize): # print('max memory(MiB):', max_memory) -def quantize(model_id: str, wbits: int, groupsize: int): +def quantize(model_id: str, bits: int, groupsize: int, output_dir: str): print("loading model") - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0") + model = AutoModelForCausalLM.from_pretrained( + model_id, torch_dtype=torch.float16, device_map="balanced_low_0" + ) print("LOADED model") model.seqlen = 2048 @@ -931,11 +1056,12 @@ def quantize(model_id: str, wbits: int, groupsize: int): nsamples = 128 seed = None - - dataloader, testloader = get_loaders(dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen) + dataloader, testloader = get_loaders( + dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen + ) tick = time.time() - quantizers = sequential(model, dataloader, DEV, nsamples, wbits, groupsize) + quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize) print(time.time() - tick) # if args.benchmark: @@ -956,7 +1082,7 @@ def quantize(model_id: str, wbits: int, groupsize: int): # dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) # print(dataset) # llama_eval(model, testloader, DEV) - # + # # if args.test_generation: # gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] # if len(gpus) > 1: @@ -970,20 +1096,57 @@ def quantize(model_id: str, wbits: int, groupsize: int): # streamer = TextStreamer(tokenizer) # with torch.no_grad(): # generated_ids = model.generate(input_ids, streamer=streamer) - # - + # # if args.quant_directory is not None: # export_quant_table(quantizers, args.quant_directory) # if not args.observe and args.save: - # llama_pack(model, quantizers, args.wbits, args.groupsize) + # llama_pack(model, quantizers, args.bits, args.groupsize) # torch.save(model.state_dict(), args.save) # if not args.observe and args.save_safetensors: - pack(model, quantizers, wbits, groupsize) - from safetensors.torch import save_file as safe_save - state_dict = model.state_dict() - state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} - safe_save(state_dict, args.save_safetensors) + pack(model, quantizers, bits, groupsize) + from safetensors.torch import save_file + from transformers.modeling_utils import shard_checkpoint + state_dict = model.state_dict() + state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} + 
state_dict["gptq_bits"] = torch.LongTensor(bits) + state_dict["gptq_groupsize"] = torch.LongTensor(groupsize) + + shards, index = shard_checkpoint( + state_dict, max_shard_size="10GB", weights_name="model.safetensors" + ) + os.makedirs(output_dir, exist_ok=True) + for shard_file, shard in shards.items(): + save_file( + shard, + os.path.join(output_dir, shard_file), + metadata={ + "format": "pt", + "quantized": "gptq", + "origin": "text-generation-inference", + }, + ) + if index is None: + path_to_weights = os.path.join(save_directory, "model.safetensors") + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = "model.safetensors.index.json" + save_index_file = os.path.join(save_directory, save_index_file) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + config = AutoConfig.from_pretrained(model_id) + config.save_pretrained(output_dir) + logger.info("Saved config") + logger.info("Saving tokenizer") + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.save_pretrained(output_dir) + logger.info("Saved tokenizer") diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 55f51f5a..e85e8f2f 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -134,13 +134,15 @@ def get_linear(weight, bias, quantize): try: qweight, qzeros, scales, g_idx, bits, groupsize = weight except Exception: - raise NotImplementedError(f"The passed weight is not `gptq` compatible, loader needs to be updated.") + raise NotImplementedError( + f"The passed weight is not `gptq` compatible, loader needs to be updated." 
+ ) linear = QuantLinear( qweight, qzeros, scales, - g_idx, + g_idx, bias, bits, groupsize, diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 16ef87a5..bc3e284c 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -86,19 +86,27 @@ class Weights: def get_multi_weights_col(self, prefixes: List[str], quantize: str): if quantize == "gptq": try: - qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1) + qweight = torch.cat( + [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) except RuntimeError: - raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) - qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1) - scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1) + qzeros = torch.cat( + [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + scales = torch.cat( + [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) g_idx = w[0] - # TODO Get that from file to be more generic - bits = 4 - groupsize = 128 + + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() weight = (qweight, qzeros, scales, g_idx, bits, groupsize) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] @@ -110,14 +118,15 @@ class Weights: try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) except RuntimeError: - raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) qzeros = self.get_tensor(f"{prefix}.qzeros") scales = self.get_tensor(f"{prefix}.scales") g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - # TODO Get that from file to be more generic - bits = 4 - groupsize = 128 + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() weight = (qweight, qzeros, scales, g_idx, bits, groupsize) else:
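
Note (editor's illustration, not part of the diff): the loader above recovers the quantization parameters with self.get_tensor("gptq_bits").item(), which only works if the integers were stored as one-element tensors. A minimal, self-contained sketch of that round trip follows; the file name is hypothetical and the snippet assumes only that torch and safetensors are installed.

import torch
from safetensors.torch import save_file, load_file

bits, groupsize = 4, 128

# Wrap the ints in a one-element list: torch.LongTensor(4) would allocate an
# uninitialized tensor of length 4 instead of storing the value 4.
tensors = {
    "gptq_bits": torch.LongTensor([bits]),
    "gptq_groupsize": torch.LongTensor([groupsize]),
}
save_file(tensors, "gptq_metadata.safetensors", metadata={"format": "pt"})  # hypothetical path

loaded = load_file("gptq_metadata.safetensors")
assert loaded["gptq_bits"].item() == 4        # mirrors what the weight loader reads back
assert loaded["gptq_groupsize"].item() == 128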