diff --git a/launcher/src/main.rs b/launcher/src/main.rs index cbb6f25d..5255dd7d 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -25,6 +25,7 @@ enum Quantization { BitsandbytesNF4, BitsandbytesFP4, Gptq, + Eetq, } impl std::fmt::Display for Quantization { @@ -43,6 +44,9 @@ impl std::fmt::Display for Quantization { Quantization::Gptq => { write!(f, "gptq") } + Quantization::Eetq => { + write!(f, "eetq") + } } } } diff --git a/server/Makefile b/server/Makefile index a4ce6d8b..5e9c69cb 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,6 +1,7 @@ include Makefile-flash-att include Makefile-flash-att-v2 include Makefile-vllm +include Makefile-eetq unit-tests: pytest -s -vv -m "not private" tests diff --git a/server/Makefile-eetq b/server/Makefile-eetq new file mode 100644 index 00000000..5e8e9830 --- /dev/null +++ b/server/Makefile-eetq @@ -0,0 +1,13 @@ +eetq_commit := 323827dd471458a84e9c840f614e4592b157a4b1 + +eetq: + # Clone eetq + pip install packaging + git clone https://github.com/NetEase-FuXi/EETQ.git eetq + +build-eetq: eetq + cd eetq && git fetch && git checkout $(eetq_commit) + cd eetq && python setup.py build + +install-eetq: build-eetq + cd eetq && python setup.py install diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index e3fda07f..eab983a9 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -17,6 +17,7 @@ class Quantization(str, Enum): bitsandbytes_nf4 = "bitsandbytes-nf4" bitsandbytes_fp4 = "bitsandbytes-fp4" gptq = "gptq" + eetq = "eetq" class Dtype(str, Enum): diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index c1c36194..1555ac72 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -35,6 +35,13 @@ elif CAN_EXLLAMA: from typing import Optional +HAS_EETQ = False +try: + from EETQ import quant_weights, w8_a16_gemm + HAS_EETQ = True +except ImportError: + pass + # Monkey patching @classmethod def load_layer_norm(cls, prefix, weights, eps): @@ -113,6 +120,30 @@ class FastLinear(nn.Module): return F.linear(input, self.weight, self.bias) +class WeightOnlyLinear(nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + device = weight.device + weight = torch.t(weight).contiguous().cpu() + weight, scale = quant_weights(weight, torch.int8, False) + if bias: + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + self.weight = weight.cuda(device) + self.scale = scale.cuda(device) + self.bias = bias.cuda(device) if bias is not None else None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = w8_a16_gemm(input, self.weight, self.scale) + output = output + self.bias if self.bias is not None else output + return output + + class Linear8bitLt(nn.Module): def __init__( self, @@ -207,6 +238,11 @@ class Linear4bit(nn.Module): def get_linear(weight, bias, quantize): if quantize is None: linear = FastLinear(weight, bias) + elif quantize == "eetq": + if HAS_EETQ: + linear = WeightOnlyLinear(weight, bias) + else: + raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ") elif quantize == "bitsandbytes": linear = Linear8bitLt( weight, @@ -284,7 +320,7 @@ class TensorParallelHead(SuperLayer): should_gather = False # GPTQ doesn't quantize heads (nor embeddings) - if config.quantize == "gptq": + if config.quantize in ["gptq", "eetq"]: quantize = None else: quantize = config.quantize