Support weight-only quantization

This commit is contained in:
zhaosida 2023-09-27 10:33:55 +08:00 committed by Nicolas Patry
parent a049864270
commit 2d13b6ff6c
5 changed files with 58 additions and 2 deletions

View File

@@ -26,6 +26,7 @@ enum Quantization {
     BitsandbytesFP4,
     Gptq,
     Awq,
+    Eetq,
 }
 
 impl std::fmt::Display for Quantization {
@@ -46,6 +47,8 @@ impl std::fmt::Display for Quantization {
             }
             Quantization::Awq => {
                 write!(f, "awq")
             }
+            Quantization::Eetq => {
+                write!(f, "eetq")
             }
         }
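With the Eetq variant and its Display arm in place, the launcher should now accept eetq as a value of its existing quantization option and pass the literal string "eetq" down to the server, e.g. a launch command along the lines of `text-generation-launcher --model-id <model> --quantize eetq` (the model id here is a placeholder; the rest of the command is unchanged).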

View File

@@ -2,6 +2,7 @@ include Makefile-flash-att
 include Makefile-flash-att-v2
 include Makefile-vllm
 include Makefile-awq
+include Makefile-eetq
 
 unit-tests:
 	pytest -s -vv -m "not private" tests

server/Makefile-eetq (new file, 13 additions)
View File

@@ -0,0 +1,13 @@
+eetq_commit := 323827dd471458a84e9c840f614e4592b157a4b1
+
+eetq:
+	# Clone eetq
+	pip install packaging
+	git clone https://github.com/NetEase-FuXi/EETQ.git eetq
+
+build-eetq: eetq
+	cd eetq && git fetch && git checkout $(eetq_commit)
+	cd eetq && python setup.py build
+
+install-eetq: build-eetq
+	cd eetq && python setup.py install
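Assuming these targets follow the same workflow as the other kernel Makefiles included above, running `make install-eetq` from the server directory should clone EETQ, check out the pinned commit, build it, and install it into the active Python environment; `packaging` is installed first, presumably because EETQ's build step needs it.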

View File

@@ -18,6 +18,7 @@ class Quantization(str, Enum):
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
     awq = "awq"
+    eetq = "eetq"
 
 
 class Dtype(str, Enum):

View File

@@ -42,6 +42,13 @@ elif CAN_EXLLAMA:
 
 from typing import Optional
 
+HAS_EETQ = False
+try:
+    from EETQ import quant_weights, w8_a16_gemm
+
+    HAS_EETQ = True
+except ImportError:
+    pass
 
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -120,6 +127,30 @@ class FastLinear(nn.Module):
         return F.linear(input, self.weight, self.bias)
 
 
+class EETQLinear(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        device = weight.device
+        weight = torch.t(weight).contiguous().cpu()
+        weight, scale = quant_weights(weight, torch.int8, False)
+
+        self.weight = weight.cuda(device)
+        self.scale = scale.cuda(device)
+        self.bias = bias.cuda(device) if bias is not None else None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = w8_a16_gemm(input, self.weight, self.scale)
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
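EETQLinear pays the quantization cost once, in the constructor: the fp16 weight is transposed, quantized to int8 on a CPU copy by quant_weights, and moved back to the original device together with its scale, while forward runs the weight-only int8 x fp16 GEMM. A minimal usage sketch, assuming EETQ and a CUDA device are available, that EETQLinear as defined above is in scope, and that the weight uses the usual nn.Linear (out_features, in_features) layout; all shapes below are illustrative:

import torch

out_features, in_features = 4096, 4096
fp16_weight = torch.randn(out_features, in_features, dtype=torch.float16, device="cuda")

# Quantization happens here, once; afterwards the int8 weight and its scale live on the GPU.
layer = EETQLinear(fp16_weight, bias=None)

x = torch.randn(2, in_features, dtype=torch.float16, device="cuda")
y = layer(x)  # w8_a16_gemm on int8 weights and fp16 activations; expected shape (2, out_features)

Because the interface is just (weight, bias) in and a linear forward out, it can replace FastLinear in get_linear below without any model-side changes.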
@@ -214,7 +245,14 @@ class Linear4bit(nn.Module):
 def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
+    elif quantize == "eetq":
+        if HAS_EETQ:
+            linear = EETQLinear(weight, bias)
+        else:
+            raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
     elif quantize == "bitsandbytes":
+        import warnings
+        warnings.warn("Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performance", DeprecationWarning)
         linear = Linear8bitLt(
             weight,
             bias,
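A small dispatch sketch for the branch above, under the same assumptions as the previous example (EETQ importable, fp16_weight already on the GPU):

# With EETQ installed, the "eetq" branch returns the new layer;
# without it, the same call raises the ImportError shown in the diff.
linear = get_linear(fp16_weight, None, "eetq")
assert isinstance(linear, EETQLinear)

Requesting "bitsandbytes" still works, but now emits a DeprecationWarning steering users toward eetq as the drop-in replacement.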
@@ -298,8 +336,8 @@ class TensorParallelHead(SuperLayer):
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
 
-        # GPTQ and AWQ don't quantize heads (nor embeddings)
-        if config.quantize in ["gptq", "awq"]:
+        # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
+        if config.quantize in ["gptq", "awq", "eetq"]:
             quantize = None
         else:
             quantize = config.quantize