Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Support weight only quantization

commit 2d13b6ff6c
parent a049864270
@@ -26,6 +26,7 @@ enum Quantization {
     BitsandbytesFP4,
     Gptq,
     Awq,
+    Eetq,
 }

 impl std::fmt::Display for Quantization {
@@ -46,6 +47,8 @@ impl std::fmt::Display for Quantization {
             Quantization::Awq => {
                 write!(f, "awq")
             }
+            Quantization::Eetq => {
+                write!(f, "eetq")
+            }
         }
     }
 }
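Together, the new `Eetq` variant and its `Display` arm let the launcher forward the quantization mode to the Python server as the string "eetq"; presumably this is selected through the launcher's existing `--quantize` option, which is not itself shown in this diff.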
@@ -2,6 +2,7 @@ include Makefile-flash-att
 include Makefile-flash-att-v2
 include Makefile-vllm
 include Makefile-awq
+include Makefile-eetq

 unit-tests:
 	pytest -s -vv -m "not private" tests
server/Makefile-eetq (new file, 13 lines)
@@ -0,0 +1,13 @@
+eetq_commit := 323827dd471458a84e9c840f614e4592b157a4b1
+
+eetq:
+	# Clone eetq
+	pip install packaging
+	git clone https://github.com/NetEase-FuXi/EETQ.git eetq
+
+build-eetq: eetq
+	cd eetq && git fetch && git checkout $(eetq_commit)
+	cd eetq && python setup.py build
+
+install-eetq: build-eetq
+	cd eetq && python setup.py install
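With this file pulled in by the `include Makefile-eetq` line above, running `make install-eetq` from the directory containing these Makefiles should walk the target chain eetq → build-eetq → install-eetq: install `packaging`, clone EETQ, check out the pinned `eetq_commit`, build the extension, and install it into the active Python environment.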
@@ -18,6 +18,7 @@ class Quantization(str, Enum):
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
     awq = "awq"
+    eetq = "eetq"


 class Dtype(str, Enum):
@@ -42,6 +42,13 @@ elif CAN_EXLLAMA:

 from typing import Optional

+HAS_EETQ = False
+try:
+    from EETQ import quant_weights, w8_a16_gemm
+    HAS_EETQ = True
+except ImportError:
+    pass
+
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
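The `ImportError` is deliberately swallowed here: `HAS_EETQ` only records whether the compiled `EETQ` extension could be imported, and the user-facing error is raised later in `get_linear` (see below) when `eetq` is actually requested.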
@@ -120,6 +127,30 @@ class FastLinear(nn.Module):
         return F.linear(input, self.weight, self.bias)


+class EETQLinear(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        device = weight.device
+        weight = torch.t(weight).contiguous().cpu()
+        weight, scale = quant_weights(weight, torch.int8, False)
+        if bias:
+            bias = weights.get_tensor(f"{prefix}.bias")
+        else:
+            bias = None
+        self.weight = weight.cuda(device)
+        self.scale = scale.cuda(device)
+        self.bias = bias.cuda(device) if bias is not None else None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = w8_a16_gemm(input, self.weight, self.scale)
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
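As committed, the `if bias:` branch references `weights` and `prefix`, which are not in scope in this constructor (they belong to the weight-loading callers), so it looks like a copy-paste leftover; callers such as `get_linear` already hand in the resolved `bias` tensor. A minimal self-contained sketch of the intended layer, assuming EETQ is installed and the weight arrives as an fp16 CUDA tensor (the class name is changed to mark it as illustrative):

from typing import Optional

import torch
from torch import nn
from EETQ import quant_weights, w8_a16_gemm  # assumes the EETQ extension is installed


class EETQLinearSketch(nn.Module):
    # Weight-only int8 linear: quantize once at construction, run an fp16-activation GEMM at call time.
    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
        super().__init__()
        device = weight.device
        # Same preprocessing as the diff: transpose, make contiguous, move to CPU for quantization.
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quant_weights(weight, torch.int8, False)
        self.weight = weight.cuda(device)
        self.scale = scale.cuda(device)
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = w8_a16_gemm(input, self.weight, self.scale)
        return output if self.bias is None else output + self.bias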
@@ -214,7 +245,14 @@ class Linear4bit(nn.Module):
 def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
+    elif quantize == "eetq":
+        if HAS_EETQ:
+            linear = EETQLinear(weight, bias)
+        else:
+            raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
     elif quantize == "bitsandbytes":
+        import warnings
+        warnings.warn("Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performance", DeprecationWarning)
         linear = Linear8bitLt(
             weight,
             bias,
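A hedged usage sketch of the new branch; the import path, tensor shapes and dtype below are illustrative assumptions (a CUDA GPU and an installed EETQ are also assumed), not anything stated in the diff:

import torch

# Path assumed from where get_linear lives in this repository's server code.
from text_generation_server.utils.layers import get_linear

weight = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")  # (out_features, in_features), made-up shape
bias = None

layer = get_linear(weight, bias, quantize="eetq")  # EETQLinear when HAS_EETQ, ImportError otherwise
hidden = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
out = layer(hidden)  # runs w8_a16_gemm under the hood
print(out.shape)  # expected: torch.Size([1, 11008])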
@@ -298,8 +336,8 @@ class TensorParallelHead(SuperLayer):
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False

-        # GPTQ and AWQ don't quantize heads (nor embeddings)
-        if config.quantize in ["gptq", "awq"]:
+        # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
+        if config.quantize in ["gptq", "awq", "eetq"]:
             quantize = None
         else:
             quantize = config.quantize
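Since EETQ, like GPTQ and AWQ, only replaces the model's linear layers with weight-only quantized ones, the lm_head and embeddings are left in their original precision: for those, `quantize` is forced back to `None`, so `get_linear` builds a plain `FastLinear`.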