From 2d13b6ff6c83af77a0d25ed28fc3900cac8ddd93 Mon Sep 17 00:00:00 2001
From: zhaosida
Date: Wed, 27 Sep 2023 10:33:55 +0800
Subject: [PATCH] Support weight only quantization

---
 launcher/src/main.rs                          |  4 ++++
 server/Makefile                               |  1 +
 server/Makefile-eetq                          | 13 +++++++++++++
 server/text_generation_server/cli.py          |  1 +
 server/text_generation_server/utils/layers.py | 38 ++++++++++++++++++++++++++++++++++++--
 5 files changed, 55 insertions(+), 2 deletions(-)
 create mode 100644 server/Makefile-eetq

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 09e32f89..ce5d6d70 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -26,6 +26,7 @@ enum Quantization {
     BitsandbytesFP4,
     Gptq,
     Awq,
+    Eetq,
 }
 
 impl std::fmt::Display for Quantization {
@@ -46,6 +47,9 @@ impl std::fmt::Display for Quantization {
             }
             Quantization::Awq => {
                 write!(f, "awq")
             }
+            Quantization::Eetq => {
+                write!(f, "eetq")
+            }
         }
     }
diff --git a/server/Makefile b/server/Makefile
index b21d79d4..52543e3d 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -2,6 +2,7 @@ include Makefile-flash-att
 include Makefile-flash-att-v2
 include Makefile-vllm
 include Makefile-awq
+include Makefile-eetq
 
 unit-tests:
 	pytest -s -vv -m "not private" tests
diff --git a/server/Makefile-eetq b/server/Makefile-eetq
new file mode 100644
index 00000000..5e8e9830
--- /dev/null
+++ b/server/Makefile-eetq
@@ -0,0 +1,13 @@
+eetq_commit := 323827dd471458a84e9c840f614e4592b157a4b1
+
+eetq:
+	# Clone eetq
+	pip install packaging
+	git clone https://github.com/NetEase-FuXi/EETQ.git eetq
+
+build-eetq: eetq
+	cd eetq && git fetch && git checkout $(eetq_commit)
+	cd eetq && python setup.py build
+
+install-eetq: build-eetq
+	cd eetq && python setup.py install
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 7464934f..cf9596c9 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -18,6 +18,7 @@ class Quantization(str, Enum):
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
     awq = "awq"
+    eetq = "eetq"
 
 
 class Dtype(str, Enum):
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index fb27764c..dd1e54a7 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -42,6 +42,13 @@ elif CAN_EXLLAMA:
 
 from typing import Optional
 
+HAS_EETQ = False
+try:
+    from EETQ import quant_weights, w8_a16_gemm
+    HAS_EETQ = True
+except ImportError:
+    pass
+
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -120,6 +127,26 @@ class FastLinear(nn.Module):
         return F.linear(input, self.weight, self.bias)
 
 
+class EETQLinear(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        device = weight.device
+        weight = torch.t(weight).contiguous().cpu()
+        weight, scale = quant_weights(weight, torch.int8, False)
+        self.weight = weight.cuda(device)
+        self.scale = scale.cuda(device)
+        self.bias = bias.cuda(device) if bias is not None else None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = w8_a16_gemm(input, self.weight, self.scale)
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
@@ -214,7 +241,14 @@ class Linear4bit(nn.Module):
 def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
+    elif quantize == "eetq":
+        if HAS_EETQ:
+            linear = EETQLinear(weight, bias)
+        else:
+            raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
     elif quantize == "bitsandbytes":
+        import warnings
+        warnings.warn("Bitsandbytes 8bit is deprecated; using `eetq` is a drop-in replacement and has much better performance", DeprecationWarning)
         linear = Linear8bitLt(
             weight,
             bias,
@@ -298,8 +332,8 @@ class TensorParallelHead(SuperLayer):
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
 
-        # GPTQ and AWQ don't quantize heads (nor embeddings)
-        if config.quantize in ["gptq", "awq"]:
+        # GPTQ, AWQ and EETQ don't quantize heads (nor embeddings)
+        if config.quantize in ["gptq", "awq", "eetq"]:
             quantize = None
         else:
             quantize = config.quantize
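
Note (not part of the patch): the sketch below is one way to exercise the new "eetq" branch of get_linear after installing the kernels with `make install-eetq`. It assumes a CUDA device is available and EETQ imported successfully (HAS_EETQ is True); the 4096x4096 shape and the tensor values are arbitrary, chosen only for illustration.

# Minimal sketch: build an EETQ weight-only quantized linear layer via get_linear.
import torch
from text_generation_server.utils.layers import get_linear

# fp16 weight of shape (out_features, in_features); EETQLinear transposes it,
# quantizes to int8 on CPU, then moves the packed weight and scales back to the GPU.
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
bias = torch.zeros(4096, dtype=torch.float16, device="cuda")

linear = get_linear(weight, bias, quantize="eetq")  # returns EETQLinear when HAS_EETQ is True

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = linear(x)  # w8_a16_gemm: int8 weights, fp16 activations
print(y.shape)  # expected: torch.Size([1, 4096])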