From 16d0fb04ae40dc13b42fd3ce9dfc90ca2c426b79 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 15 Jun 2023 16:59:31 +0200
Subject: [PATCH] Santacoder GPTQ support (quantized model seems awful, not sure if it's prompting or the quantization itself).

---
 server/text_generation_server/cli.py         |  4 +
 .../flash_santacoder_modeling.py             | 81 ++++++++++++++++++-
 .../models/flash_llama.py                    | 25 ++++--
 .../utils/gptq/quantize.py                   | 21 +++--
 4 files changed, 116 insertions(+), 15 deletions(-)

diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index e872f833..aeb1f13b 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -160,6 +160,8 @@ def quantize(
     json_output: bool = False,
     trust_remote_code: bool = False,
     upload_to_model_id: Optional[str] = None,
+    percdamp: float = 0.01,
+    act_order: bool = False,
 ):
     download_weights(
         model_id=model_id,
@@ -176,6 +178,8 @@
         output_dir=output_dir,
         trust_remote_code=trust_remote_code,
         upload_to_model_id=upload_to_model_id,
+        percdamp=percdamp,
+        act_order=act_order,
     )


diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 00cc47b6..a14967c9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -21,6 +21,81 @@ from text_generation_server.utils.layers import (
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
+
+    if config.quantize == "gptq":
+        return _load_multi_mqa_gptq(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+    else:
+        return _load_multi_mqa(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+
+
+def _load_multi_mqa_gptq(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+    if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
+        world_size = weights.process_group.size()
+        rank = weights.process_group.rank()
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        qweight = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        scales = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert 2 * head_size % (32 // 4) == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
+        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+        bits = weights.get_tensor("gptq_bits").item()
+        groupsize = weights.get_tensor("gptq_groupsize").item()
+
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+
+        if bias:
+            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+            shape = slice_.get_shape()
+            block_size = (shape[0] - 2 * head_size) // world_size
+            assert (shape[0] - 2 * head_size) % world_size == 0
+            q_tensor = slice_[start:stop]
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            q_tensor = slice_[start:stop]
+            kv_tensor = slice_[-2 * head_size :]
+            bias = torch.cat([q_tensor, kv_tensor], dim=0)
+
+        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    else:
+        raise NotImplementedError("Gptq loading with santacoder is not implemented")
+
+
+def _load_multi_mqa(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
@@ -93,7 +168,9 @@ def load_col(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_multi_weights_col(
+            [prefix], quantize=config.quantize, dim=0
+        )

     if bias:
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
@@ -106,7 +183,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index eb216a20..a80d58cb 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -3,7 +3,7 @@ import torch.distributed

 from opentelemetry import trace
 from transformers import AutoConfig
-from transformers.models.llama import LlamaTokenizer
+from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
@@ -34,13 +34,22 @@ class FlashLlama(FlashCausalLM):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

-        tokenizer = LlamaTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
+        try:
+            tokenizer = LlamaTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+        except Exception:
+            tokenizer = LlamaTokenizerFast.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )

         config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py
index abde0b02..5a4ed8da 100644
--- a/server/text_generation_server/utils/gptq/quantize.py
+++ b/server/text_generation_server/utils/gptq/quantize.py
@@ -240,7 +240,7 @@ class GPTQ:
         print(table.draw().split("\n")[-2])

     def fasterquant(
-        self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
+        self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
     ):
         self.layer.to(self.dev)

@@ -263,7 +263,7 @@ class GPTQ:
         H[dead, dead] = 1
         W[:, dead] = 0

-        if actorder:
+        if act_order:
             perm = torch.argsort(torch.diag(H), descending=True)
             W = W[:, perm]
             H = H[perm][:, perm]
@@ -334,7 +334,7 @@ class GPTQ:
         groupsize = groupsize if groupsize != -1 else self.columns
         g_idx = [i // groupsize for i in range(self.columns)]
         g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
-        if actorder:
+        if act_order:
             invperm = torch.argsort(perm)
             Q = Q[:, invperm]
             g_idx = g_idx[invperm]
@@ -697,7 +697,7 @@ def sequential(
             scale, zero, g_idx, error = gptq[name].fasterquant(
                 percdamp=percdamp,
                 groupsize=groupsize,
-                actorder=act_order,
+                act_order=act_order,
                 name=name,
             )
             quantizers[f"{prefix}.{i}.{name}"] = (
@@ -775,6 +775,8 @@ def quantize(
     output_dir: str,
     trust_remote_code: bool,
     upload_to_model_id: Optional[str],
+    percdamp: float,
+    act_order: bool,
 ):
     print("loading model")
     model = AutoModelForCausalLM.from_pretrained(
@@ -795,7 +797,16 @@ def quantize(
     )

     tick = time.time()
-    quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
+    quantizers = sequential(
+        model,
+        dataloader,
+        DEV,
+        nsamples,
+        bits,
+        groupsize,
+        percdamp=percdamp,
+        act_order=act_order,
+    )
     print(time.time() - tick)

     pack(model, quantizers, bits, groupsize)
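
Note, not part of the patch: the least obvious arithmetic in _load_multi_mqa_gptq above is that qweight and scales are sliced with a key/value tail of width 2 * head_size, while qzeros uses 2 * head_size * 4 // 32. The sketch below illustrates why, assuming the usual GPTQ packing of 4-bit zero-points (eight per int32 column along the output dimension) and Santacoder-like shapes (hidden_size 2048, head_size 128, two tensor-parallel shards); all concrete numbers are illustrative assumptions, not values taken from the diff.

# Sketch only (assumed shapes): why the key/value tail of the packed qzeros
# tensor is 2 * head_size * 4 // 32 columns wide while qweight and scales
# keep a 2 * head_size wide tail.
head_size = 128            # Santacoder-like head size (assumption)
hidden_size = 2048         # Santacoder-like hidden size (assumption)
world_size = 2             # two tensor-parallel shards (assumption)
bits = 4
pack_factor = 32 // bits   # GPTQ packs 8 four-bit zero-points into one int32

# Output width of the fused c_attn projection: all query heads plus one shared
# key/value head pair (multi-query attention).
out_features = hidden_size + 2 * head_size

scales_cols = out_features                 # one column per output feature
qzeros_cols = out_features // pack_factor  # packed along the output dimension

# Per-rank query block and the shared key/value tail that every rank keeps.
q_block = (scales_cols - 2 * head_size) // world_size                      # 1024
kv_tail = 2 * head_size                                                    # 256
q_block_packed = (qzeros_cols - 2 * head_size * bits // 32) // world_size  # 128
kv_tail_packed = 2 * head_size * bits // 32                                # 32

# The packed widths are exactly the unpacked widths divided by the pack factor;
# the loader's assert 2 * head_size % (32 // 4) == 0 guards that this division
# is exact for the key/value tail.
assert q_block_packed == q_block // pack_factor
assert kv_tail_packed == kv_tail // pack_factor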