From 00dede8a631560d307bb1037ccb3d72f04acca81 Mon Sep 17 00:00:00 2001
From: Abhinav Kulkarni
Date: Wed, 13 Sep 2023 11:08:22 +0000
Subject: [PATCH] Added AWQ support for FlashLlama models

---
 launcher/src/main.rs                           |  4 ++
 server/requirements.txt                        |  2 +
 server/text_generation_server/cli.py           |  1 +
 .../text_generation_server/models/__init__.py  |  4 ++
 .../models/flash_llama.py                      |  2 +-
 server/text_generation_server/utils/layers.py  | 18 +++++++-
 .../text_generation_server/utils/weights.py    | 46 +++++++++++++++++++
 7 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index cbb6f25d..09e32f89 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -25,6 +25,7 @@ enum Quantization {
     BitsandbytesNF4,
     BitsandbytesFP4,
     Gptq,
+    Awq,
 }
 
 impl std::fmt::Display for Quantization {
@@ -43,6 +44,9 @@
             Quantization::Gptq => {
                 write!(f, "gptq")
             }
+            Quantization::Awq => {
+                write!(f, "awq")
+            }
         }
     }
 }
diff --git a/server/requirements.txt b/server/requirements.txt
index 1b038cca..ac3ac9fa 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -73,3 +73,5 @@ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
+# Custom 4-bit GEMM AWQ kernels
+git+https://github.com/mit-han-lab/llm-awq.git@f084f40bd996f3cf3a0633c1ad7d9d476c318aaa#subdirectory=awq/kernels
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index e3fda07f..e0b8c0fe 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -17,6 +17,7 @@ class Quantization(str, Enum):
     bitsandbytes_nf4 = "bitsandbytes-nf4"
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
+    awq = "awq"
 
 
 class Dtype(str, Enum):
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index e6fe1372..8fc787c6 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -268,6 +268,10 @@ def get_model(
         raise ValueError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
+    if quantize == "awq":
+        raise ValueError(
+            "awq quantization is not supported for AutoModel"
+        )
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
         raise ValueError(
             "4bit quantization is not supported for AutoModel"
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 063aa01e..d2ed0b15 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -62,7 +62,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id)
 
         model = FlashLlamaForCausalLM(config, weights)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index c1c36194..f0227177 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -17,6 +17,7 @@ except ImportError:
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
+from text_generation_server.utils.awq.quantize.qmodule import WQLinear
 
 try:
     major, _minor = torch.cuda.get_device_capability()
@@ -248,6 +249,19 @@ def get_linear(weight, bias, quantize):
             bits,
             groupsize,
         )
+    elif quantize == "awq":
+        try:
+            qweight, qzeros, scales, bits, groupsize = weight
+        except Exception:
+            raise NotImplementedError(
+                f"The passed weight is not `awq` compatible, loader needs to be updated."
+            )
+        in_features = qweight.shape[0]
+        out_features = qweight.shape[1] * 32 // bits
+        linear = WQLinear(w_bit=bits, group_size=groupsize, in_features=in_features, out_features=out_features, bias=bias is not None, dev=qweight.device)
+        linear.qweight = qweight
+        linear.qzeros = qzeros
+        linear.scales = scales
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
@@ -283,8 +297,8 @@ class TensorParallelHead(SuperLayer):
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
 
-        # GPTQ doesn't quantize heads (nor embeddings)
-        if config.quantize == "gptq":
+        # GPTQ and AWQ don't quantize heads (nor embeddings)
+        if config.quantize in ["gptq", "awq"]:
             quantize = None
         else:
             quantize = config.quantize
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 2ef7ad39..c5562a4f 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -150,6 +150,19 @@ class Weights:
             bits, groupsize = self._get_gptq_params()
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        if quantize == "awq":
+            try:
+                qweight = self._get_qweight(f"{prefix}.qweight")
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+            qzeros = self._get_qweight(f"{prefix}.qzeros")
+            scales = self._get_qweight(f"{prefix}.scales")
+            scales = scales.to(dtype=self.dtype)
+
+            bits, groupsize = self._get_gptq_params()
+            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
@@ -194,6 +207,25 @@
             bits, groupsize = self._get_gptq_params()
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        elif quantize == "awq":
+            try:
+                qweight = torch.cat(
+                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+                )
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+
+            qzeros = torch.cat(
+                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+            )
+            scales = torch.cat(
+                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+            )
+
+            bits, groupsize = self._get_gptq_params()
+            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
@@ -282,6 +314,20 @@
 
             g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+        elif quantize == "awq":
+            bits, groupsize = self._get_gptq_params()
+
+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+
+            qzeros = self.get_tensor(f"{prefix}.qzeros")
+            scales = self.get_tensor(f"{prefix}.scales")
+
+            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
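
Note on the shape arithmetic in the new `awq` branch of get_linear(): AWQ's 4-bit GEMM kernels keep the quantized weight packed into int32 words, 32 // bits values per word along the output dimension, which is why the branch recovers out_features as qweight.shape[1] * 32 // bits before constructing WQLinear. The sketch below is an illustration only (plain PyTorch, not part of the patch); the 4096x4096 example shapes are hypothetical.

    import torch

    def awq_logical_shape(qweight: torch.Tensor, bits: int = 4):
        # AWQ stores the weight matrix packed into int32: rows (in_features)
        # stay unpacked, columns hold 32 // bits quantized values per word.
        assert qweight.dtype == torch.int32
        pack_factor = 32 // bits               # 8 values per word at 4 bits
        in_features = qweight.shape[0]
        out_features = qweight.shape[1] * pack_factor
        return in_features, out_features

    # Hypothetical example: a 4096x4096 projection quantized to 4 bits is
    # stored as an int32 tensor of shape [4096, 512].
    qweight = torch.zeros((4096, 512), dtype=torch.int32)
    print(awq_logical_shape(qweight))  # (4096, 4096)

get_linear() then builds a WQLinear with those dimensions (imported in layers.py and backed by the custom 4-bit GEMM kernels pinned in requirements.txt) and points its qweight, qzeros and scales at the tensors returned by the loader.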
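
The weights.py changes mirror the existing GPTQ paths, minus g_idx: for fused column-parallel weights the packed qweight, qzeros and scales are all concatenated along dim=1 (the output dimension), while for row-parallel weights only qweight is sharded along dim=0 and qzeros/scales are read whole. A minimal sketch of the column fusion, with hypothetical shapes and the same 4-bit packing as above:

    import torch

    bits = 4
    pack = 32 // bits                                  # values per int32 word
    # Two column-parallel projections to fuse (hypothetical q/k shapes).
    q_proj = torch.zeros((4096, 4096 // pack), dtype=torch.int32)
    k_proj = torch.zeros((4096, 1024 // pack), dtype=torch.int32)
    # The output dimension is the (packed) second axis, so fusing is a plain
    # concatenation along dim=1, which is what the new multi-prefix branch does.
    fused = torch.cat([q_proj, k_proj], dim=1)
    print(fused.shape[1] * pack)                       # 5120 logical outputs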