Mirror of https://github.com/huggingface/text-generation-inference.git

Support weight only quantization

commit 2d13b6ff6c (parent a049864270)
@@ -26,6 +26,7 @@ enum Quantization {
     BitsandbytesFP4,
     Gptq,
     Awq,
+    Eetq,
 }
 
 impl std::fmt::Display for Quantization {
@@ -46,6 +47,8 @@ impl std::fmt::Display for Quantization {
             }
             Quantization::Awq => {
                 write!(f, "awq")
             }
+            Quantization::Eetq => {
+                write!(f, "eetq")
+            }
         }
     }
@@ -2,6 +2,7 @@ include Makefile-flash-att
 include Makefile-flash-att-v2
 include Makefile-vllm
 include Makefile-awq
+include Makefile-eetq
 
 unit-tests:
 	pytest -s -vv -m "not private" tests
server/Makefile-eetq (new file, 13 lines)
@@ -0,0 +1,13 @@
+eetq_commit := 323827dd471458a84e9c840f614e4592b157a4b1
+
+eetq:
+	# Clone eetq
+	pip install packaging
+	git clone https://github.com/NetEase-FuXi/EETQ.git eetq
+
+build-eetq: eetq
+	cd eetq && git fetch && git checkout $(eetq_commit)
+	cd eetq && python setup.py build
+
+install-eetq: build-eetq
+	cd eetq && python setup.py install
@@ -18,6 +18,7 @@ class Quantization(str, Enum):
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
     awq = "awq"
+    eetq = "eetq"
 
 
 class Dtype(str, Enum):
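Because `Quantization` is a `str`-backed enum, the `--quantize eetq` CLI value parses directly into the new member. A minimal standalone sketch of that lookup (the enum is re-declared here for illustration rather than imported from the server package, and only a few members are shown):

from enum import Enum


class Quantization(str, Enum):
    # Members mirrored from the diff above; the real enum has more entries.
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"


# A str-backed Enum lets a CLI string value map straight onto a member.
assert Quantization("eetq") is Quantization.eetq
assert Quantization.eetq.value == "eetq"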
@@ -42,6 +42,13 @@ elif CAN_EXLLAMA:
 
 from typing import Optional
 
+HAS_EETQ = False
+try:
+    from EETQ import quant_weights, w8_a16_gemm
+
+    HAS_EETQ = True
+except ImportError:
+    pass
 
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
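The two functions imported here are the entire surface this commit uses from EETQ. A minimal sketch of that round trip, assuming EETQ is installed and a CUDA device is available; the tensor shapes are illustrative only:

import torch
from EETQ import quant_weights, w8_a16_gemm  # requires the EETQ extension

# Quantize an fp16 weight to int8 plus scale factors.
# As in EETQLinear below, the weight is transposed and moved to CPU first.
w = torch.randn(4096, 4096, dtype=torch.float16)
q_w, scales = quant_weights(torch.t(w).contiguous().cpu(), torch.int8, False)

# Weight-only GEMM: int8 weights, fp16 activations.
x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = w8_a16_gemm(x, q_w.cuda(), scales.cuda())
print(y.shape)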
@@ -120,6 +127,30 @@ class FastLinear(nn.Module):
         return F.linear(input, self.weight, self.bias)
 
 
+class EETQLinear(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        device = weight.device
+        # EETQ quantizes a transposed, contiguous CPU copy of the fp16 weight
+        # into an int8 tensor plus scale factors.
+        weight = torch.t(weight).contiguous().cpu()
+        weight, scale = quant_weights(weight, torch.int8, False)
+        self.weight = weight.cuda(device)
+        self.scale = scale.cuda(device)
+        self.bias = bias.cuda(device) if bias is not None else None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = w8_a16_gemm(input, self.weight, self.scale)
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
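A usage sketch of the class above, wrapping the parameters of an existing fp16 nn.Linear. The import path is assumed to match the file this hunk edits, the layer size is illustrative, and EETQ plus a CUDA device are required:

import torch
from torch import nn

from text_generation_server.utils.layers import EETQLinear  # assumed import path

fp16 = nn.Linear(4096, 4096, bias=True, dtype=torch.float16, device="cuda")
layer = EETQLinear(fp16.weight.detach(), fp16.bias.detach())

x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
y = layer(x)  # w8_a16_gemm on the int8 weight, then the fp16 bias is added
print(y.shape)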
@@ -214,7 +245,14 @@ class Linear4bit(nn.Module):
 def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
+    elif quantize == "eetq":
+        if HAS_EETQ:
+            linear = EETQLinear(weight, bias)
+        else:
+            raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
     elif quantize == "bitsandbytes":
+        import warnings
+        warnings.warn("Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performance", DeprecationWarning)
         linear = Linear8bitLt(
             weight,
             bias,
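With this dispatch in place, callers only flip the `quantize` string. A hedged sketch of exercising the new branch (import path as in this diff, sizes illustrative):

import torch
from text_generation_server.utils.layers import get_linear

weight = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")

# With EETQ installed this returns an EETQLinear; otherwise it raises the
# ImportError pointing at https://github.com/NetEase-FuXi/EETQ.
layer = get_linear(weight, bias=None, quantize="eetq")
out = layer(torch.randn(2, 4096, dtype=torch.float16, device="cuda"))
print(type(layer).__name__, out.shape)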
@@ -298,8 +336,8 @@ class TensorParallelHead(SuperLayer):
         weight = weights.get_tensor(f"{prefix}.weight")
         should_gather = False
 
-        # GPTQ and AWQ don't quantize heads (nor embeddings)
-        if config.quantize in ["gptq", "awq"]:
+        # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
+        if config.quantize in ["gptq", "awq", "eetq"]:
             quantize = None
         else:
             quantize = config.quantize
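The rule changed here can be read as a tiny helper: GPTQ, AWQ and now EETQ quantize only the transformer linear layers, so the LM head and embeddings fall back to an unquantized path. A standalone sketch of that selection (the function name is illustrative, not from the commit):

def head_quantization(config_quantize):
    """Return the quantization mode to use for the LM head / embeddings."""
    # GPTQ, AWQ and EETQ skip heads and embeddings, so the head
    # is loaded as a plain FastLinear.
    if config_quantize in ["gptq", "awq", "eetq"]:
        return None
    return config_quantize


assert head_quantization("eetq") is None
assert head_quantization("bitsandbytes") == "bitsandbytes"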