From a0d55358d20cf3f86083b79765b80704ce9d8404 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 25 Jul 2023 12:00:27 +0100
Subject: [PATCH] feat(server): Using `quantize_config.json` instead of
 GPTQ_BITS env variables. (#671)

- The current PR is not great because we're side-stepping `Weights.__init__`,
  but `Weights` shouldn't require anything related to the config or the
  model_id, since it aims to be a simple wrapper over multi-file loading.
- The ideal solution would be to use something like a Rust enum

  ```
  enum Quantize {
      Bitsandbytes(Bitsandbytes),
      Gptq { bits: usize, groupsize: usize },
  }
  ```

  and pass that around during load. Unfortunately we don't have access to
  this, so for now side-stepping seems easier.
- Re-enabling groupsize < 0 with exllama (confirmed it works).

Helps #601

As a next step we should make sure our quantization script uses that format
and make it standard.

# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other
      checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
      [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
      [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and
      [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed.
Feel free to tag members/contributors who may be interested in your PR.
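For reference, this is the shape of the `quantize_config.json` that
`_set_gptq_params` consumes. Only `bits` and `group_size` are read; the
remaining keys are an assumption about what GPTQ quantization scripts
(e.g. AutoGPTQ) typically write alongside them, shown purely for
illustration.

```
import json

# Hypothetical quantize_config.json contents. Only "bits" and "group_size"
# are read by Weights._set_gptq_params; the other keys mirror what GPTQ
# quantization scripts commonly emit and are ignored by this loader.
example_quantize_config = json.loads(
    """
    {
        "bits": 4,
        "group_size": 128,
        "desc_act": false,
        "sym": true
    }
    """
)

assert example_quantize_config["bits"] == 4
assert example_quantize_config["group_size"] == 128
```

If the file is missing from the model repo, `_set_gptq_params` is a silent
no-op and `_get_gptq_params` falls back to the checkpoint-embedded
`gptq_bits`/`gptq_groupsize` tensors, or re-raises the original lookup error.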
--- server/text_generation_server/models/bloom.py | 2 ++ .../flash_santacoder_modeling.py | 2 +- .../models/custom_modeling/opt_modeling.py | 20 ++++++------ .../models/flash_llama.py | 4 ++- .../models/flash_neox.py | 2 ++ .../text_generation_server/models/flash_rw.py | 3 ++ .../models/flash_santacoder.py | 5 +++ .../models/galactica.py | 2 ++ .../text_generation_server/models/gpt_neox.py | 2 ++ server/text_generation_server/models/mpt.py | 2 ++ server/text_generation_server/models/opt.py | 2 ++ .../text_generation_server/utils/weights.py | 31 ++++++++++++------- 12 files changed, 53 insertions(+), 24 deletions(-) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 101da207..79fb60c6 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -76,6 +76,8 @@ class BLOOMSharded(CausalLM): weights = Weights( filenames, device=device, dtype=dtype, process_group=self.process_group ) + if config.quantize == "gptq": + weights._set_gptq_params(model_id) model = BloomForCausalLM(config, weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c16b2bf7..2dd0a5ee 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -76,7 +76,7 @@ def _load_multi_mqa_gptq( g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = g_idx.to(device=weights.device) - bits, groupsize = weights._get_gptq_qparams() + bits, groupsize = weights._get_gptq_params() from text_generation_server.utils.layers import HAS_EXLLAMA diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index aa052b08..5d1a4b0d 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -130,17 +130,17 @@ class OPTAttention(nn.Module): process_group=None, ): super().__init__() - embed_dim = config.embed_dim + hidden_size = config.hidden_size num_heads = config.num_attention_heads - self.embed_dim = embed_dim + self.hidden_size = hidden_size self.num_heads = num_heads self.dropout = config.dropout - self.head_dim = embed_dim // num_heads + self.head_dim = hidden_size // num_heads - if (self.head_dim * num_heads) != self.embed_dim: + if (self.head_dim * num_heads) != self.hidden_size: raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {num_heads})." 
) self.scaling = self.head_dim**-0.5 @@ -153,7 +153,7 @@ class OPTAttention(nn.Module): f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // process_group.size() - self.embed_dim = self.embed_dim // process_group.size() + self.hidden_size = self.hidden_size // process_group.size() self.q_proj = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias @@ -300,9 +300,9 @@ class OPTAttention(nn.Module): attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned aross GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size) attn_output = self.out_proj(attn_output) @@ -313,7 +313,7 @@ class OPTDecoderLayer(nn.Module): def __init__(self, layer_id: int, config: OPTConfig, weights): super().__init__() self.process_group = weights.process_group - self.embed_dim = config.hidden_size + self.hidden_size = config.hidden_size prefix = f"model.decoder.layers.{layer_id}" self.self_attn = OPTAttention( config, @@ -352,7 +352,7 @@ class OPTDecoderLayer(nn.Module): ]: """ Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index b699799e..96fb0c26 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -55,13 +55,15 @@ class FlashLlama(FlashCausalLM): config = LlamaConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) + config.quantize = quantize torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize == "gptq": + weights._set_gptq_params(model_id) - config.quantize = quantize model = FlashLlamaForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 61004d8e..58f345a9 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -52,6 +52,8 @@ class FlashNeoXSharded(FlashCausalLM): weights = Weights( filenames, device=device, dtype=dtype, process_group=self.process_group ) + if config.quantize == "gptq": + weights._set_gptq_params(model_id) model = FlashGPTNeoXForCausalLM(config, weights) diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 55d555fc..9e0080a9 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -58,6 +58,9 @@ class FlashRWSharded(FlashCausalLM): ) config.quantize = quantize + if config.quantize == "gptq": + weights._set_gptq_params(model_id) + model = FlashRWForCausalLM(config, weights) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 415ec2df..29505902 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -4,7 +4,10 @@ import torch.distributed from opentelemetry import trace from transformers import AutoTokenizer, AutoConfig from typing import Optional, List +import json +import os +from huggingface_hub import hf_hub_download from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( FlashSantacoderForCausalLM, @@ -59,6 +62,8 @@ class FlashSantacoderSharded(FlashCausalLM): process_group=self.process_group, aliases={"transformer.wte.weight": ["lm_head.weight"]}, ) + if config.quantize == "gptq": + weights._set_gptq_params(model_id) model = FlashSantacoderForCausalLM(config, weights) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 01e58bad..d4211734 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -191,6 +191,8 @@ class GalacticaSharded(CausalLM): weights = Weights( filenames, device=device, dtype=dtype, process_group=self.process_group ) + if config.quantize == "gptq": + weights._set_gptq_params(model_id) model = OPTForCausalLM(config, weights) diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 
91877fa0..accedf14 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -56,6 +56,8 @@ class GPTNeoxSharded(CausalLM): weights = Weights( filenames, device=device, dtype=dtype, process_group=self.process_group ) + if config.quantize == "gptq": + weights._set_gptq_params(model_id) model = GPTNeoxForCausalLM(config, weights) diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py index a4fe5105..909d9852 100644 --- a/server/text_generation_server/models/mpt.py +++ b/server/text_generation_server/models/mpt.py @@ -78,6 +78,8 @@ class MPTSharded(CausalLM): filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize == "gptq": + weights._set_gptq_params(model_id) config.quantize = quantize model = MPTForCausalLM(config, weights) diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index d407b44a..f3a23d07 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -54,6 +54,8 @@ class OPTSharded(CausalLM): weights = Weights( filenames, device=device, dtype=dtype, process_group=self.process_group ) + if config.quantize == "gptq": + weights._set_gptq_params(model_id) model = OPTForCausalLM(config, weights) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index dae53509..0330402d 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -3,6 +3,8 @@ from typing import List, Dict, Optional, Tuple from safetensors import safe_open, SafetensorError import torch from loguru import logger +from huggingface_hub import hf_hub_download +import json class Weights: @@ -128,7 +130,7 @@ class Weights: torch.testing.assert_close(w2, w[0]) g_idx = w[0] - bits, groupsize = self._get_gptq_qparams() + bits, groupsize = self._get_gptq_params() weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] @@ -138,7 +140,7 @@ class Weights: def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq": use_exllama = True - bits, groupsize = self._get_gptq_qparams() + bits, groupsize = self._get_gptq_params() if bits != 4: use_exllama = False @@ -185,11 +187,8 @@ class Weights: qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) scales = self.get_sharded(f"{prefix}.scales", dim=0) else: - raise RuntimeError( - "Using exllama GPTQ kernel with groupsize<1 is not supported" - ) - # qzeros = self.get_tensor(f"{prefix}.qzeros") - # scales = self.get_tensor(f"{prefix}.scales") + qzeros = self.get_tensor(f"{prefix}.qzeros") + scales = self.get_tensor(f"{prefix}.scales") # For tp > 1, at this point we know we do not use act-order if self.process_group.size() == 1: @@ -208,17 +207,25 @@ class Weights: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight - def _get_gptq_qparams(self) -> Tuple[int, int]: + def _get_gptq_params(self) -> Tuple[int, int]: try: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() except (SafetensorError, RuntimeError) as e: try: - import os - - bits = int(os.getenv("GPTQ_BITS")) - groupsize = int(os.getenv("GPTQ_GROUPSIZE")) + bits = self.gptq_bits + groupsize = self.gptq_groupsize except Exception: raise e 
return bits, groupsize + + def _set_gptq_params(self, model_id): + try: + filename = hf_hub_download(model_id, filename="quantize_config.json") + with open(filename, "r") as f: + data = json.load(f) + self.gptq_bits = data["bits"] + self.gptq_groupsize = data["group_size"] + except Exception: + pass
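To make the resulting lookup order concrete, here is a minimal,
self-contained sketch of the plumbing this patch ends up with. It uses a
stand-in class rather than the real `Weights`, reads a local path instead of
calling `hf_hub_download(model_id, filename="quantize_config.json")`, and
stubs out `get_tensor`; those substitutions are assumptions made only to keep
the example runnable.

```
import json
from typing import Tuple


class WeightsSketch:
    """Stand-in for Weights, showing only the GPTQ parameter plumbing."""

    def get_tensor(self, name: str):
        # The real class looks this up in the safetensors shards; the sketch
        # pretends the checkpoint does not embed gptq_bits / gptq_groupsize.
        raise RuntimeError(f"weight {name} does not exist")

    def _set_gptq_params(self, quantize_config_path: str) -> None:
        # The real code downloads quantize_config.json from the Hub with
        # hf_hub_download; here we simply read a local file (assumption).
        try:
            with open(quantize_config_path, "r") as f:
                data = json.load(f)
            self.gptq_bits = data["bits"]
            self.gptq_groupsize = data["group_size"]
        except Exception:
            # No config found: rely on whatever the checkpoint provides.
            pass

    def _get_gptq_params(self) -> Tuple[int, int]:
        # Resolution order introduced by this patch:
        # 1) gptq_bits / gptq_groupsize tensors baked into the checkpoint,
        # 2) attributes set from quantize_config.json,
        # otherwise the original lookup error is re-raised.
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
        except RuntimeError as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
            except AttributeError:
                raise e
        return bits, groupsize


weights = WeightsSketch()
weights._set_gptq_params("quantize_config.json")  # silently no-op if absent
# weights._get_gptq_params() would now return (bits, group_size) when the
# config was found, and raise the original RuntimeError otherwise.
```

The model classes in the diff above all follow the same pattern: build
`Weights(...)` first, then call `weights._set_gptq_params(model_id)` only when
`config.quantize == "gptq"`, which keeps loading unchanged for checkpoints
that still embed the `gptq_bits`/`gptq_groupsize` tensors.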