Support weight-only quantization

zhaosida 2023-09-27 10:33:55 +08:00
parent c8a01d7591
commit f29af8f38e
5 changed files with 56 additions and 1 deletion


@@ -25,6 +25,7 @@ enum Quantization {
    BitsandbytesNF4,
    BitsandbytesFP4,
    Gptq,
    Eetq,
}

impl std::fmt::Display for Quantization {
@@ -43,6 +44,9 @@ impl std::fmt::Display for Quantization {
            Quantization::Gptq => {
                write!(f, "gptq")
            }
            Quantization::Eetq => {
                write!(f, "eetq")
            }
        }
    }
}


@@ -1,6 +1,7 @@
include Makefile-flash-att
include Makefile-flash-att-v2
include Makefile-vllm
include Makefile-eetq

unit-tests:
	pytest -s -vv -m "not private" tests

server/Makefile-eetq (new file)

@@ -0,0 +1,13 @@
eetq_commit := 323827dd471458a84e9c840f614e4592b157a4b1

eetq:
	# Clone eetq
	pip install packaging
	git clone https://github.com/NetEase-FuXi/EETQ.git eetq

build-eetq: eetq
	cd eetq && git fetch && git checkout $(eetq_commit)
	cd eetq && python setup.py build

install-eetq: build-eetq
	cd eetq && python setup.py install
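
After make install-eetq has been run in the server environment, the build can be sanity-checked by importing the two kernel entry points that the server code below relies on (a minimal sketch, not part of this commit):

    # Verify the EETQ extension built by server/Makefile-eetq is importable.
    from EETQ import quant_weights, w8_a16_gemm

    print("EETQ kernels available:", callable(quant_weights), callable(w8_a16_gemm))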


@@ -17,6 +17,7 @@ class Quantization(str, Enum):
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    eetq = "eetq"


class Dtype(str, Enum):
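
Because Quantization is a str-valued Enum, the new member maps one-to-one onto the string accepted on the server command line (a small illustrative sketch, not part of this commit):

    # "eetq" round-trips between the CLI string and the enum member.
    assert Quantization("eetq") is Quantization.eetq
    assert Quantization.eetq.value == "eetq"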


@@ -35,6 +35,13 @@ elif CAN_EXLLAMA:
from typing import Optional

HAS_EETQ = False
try:
    from EETQ import quant_weights, w8_a16_gemm

    HAS_EETQ = True
except ImportError:
    pass


# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
@@ -113,6 +120,30 @@ class FastLinear(nn.Module):
        return F.linear(input, self.weight, self.bias)


class WeightOnlyLinear(nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        device = weight.device
        # EETQ quantizes the transposed weight on CPU and returns the int8
        # weight together with its dequantization scales.
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quant_weights(weight, torch.int8, False)

        self.weight = weight.cuda(device)
        self.scale = scale.cuda(device)
        # bias is either a tensor on the target device or None.
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = w8_a16_gemm(input, self.weight, self.scale)
        output = output + self.bias if self.bias is not None else output
        return output


class Linear8bitLt(nn.Module):
    def __init__(
        self,
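
For reference, the new layer can be exercised on its own by wrapping the weight of an ordinary fp16 nn.Linear (a minimal sketch, assuming EETQ is installed and a CUDA device is available; shapes are illustrative only):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Throwaway fp16 reference layer on the GPU.
    ref = nn.Linear(4096, 4096, bias=True, dtype=torch.float16, device="cuda")

    # Weight-only quantization: int8 weights, fp16 activations.
    wol = WeightOnlyLinear(ref.weight.detach(), ref.bias.detach())

    x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
    out_fp16 = F.linear(x, ref.weight, ref.bias)
    out_int8 = wol(x)

    # The int8 weight-only output should stay close to the fp16 reference.
    print((out_fp16 - out_int8).abs().max())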
@@ -207,6 +238,11 @@ class Linear4bit(nn.Module):
def get_linear(weight, bias, quantize):
    if quantize is None:
        linear = FastLinear(weight, bias)
    elif quantize == "eetq":
        if HAS_EETQ:
            linear = WeightOnlyLinear(weight, bias)
        else:
            raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
    elif quantize == "bitsandbytes":
        linear = Linear8bitLt(
            weight,
@@ -284,7 +320,7 @@ class TensorParallelHead(SuperLayer):
             should_gather = False
 
         # GPTQ doesn't quantize heads (nor embeddings)
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "eetq"]:
             quantize = None
         else:
             quantize = config.quantize
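
Taken together, the quantize value selected on the CLI reaches get_linear for every projection, while the head path above deliberately keeps the lm_head unquantized. A minimal sketch of the dispatch, assuming EETQ is installed and using an illustrative shape:

    import torch

    weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

    proj = get_linear(weight, None, quantize="eetq")  # WeightOnlyLinear; ImportError without EETQ
    head = get_linear(weight, None, quantize=None)    # FastLinear, as forced for TensorParallelHead
    print(type(proj).__name__, type(head).__name__)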