Commit f29af8f38e: Support weight only quantization
Repository: https://github.com/huggingface/text-generation-inference.git
Parent commit: c8a01d7591
@@ -25,6 +25,7 @@ enum Quantization {
     BitsandbytesNF4,
     BitsandbytesFP4,
     Gptq,
+    Eetq,
 }

 impl std::fmt::Display for Quantization {
@@ -43,6 +44,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Gptq => {
                 write!(f, "gptq")
             }
+            Quantization::Eetq => {
+                write!(f, "eetq")
+            }
         }
     }
 }
@@ -1,6 +1,7 @@
 include Makefile-flash-att
 include Makefile-flash-att-v2
 include Makefile-vllm
+include Makefile-eetq

 unit-tests:
	pytest -s -vv -m "not private" tests
server/Makefile-eetq (new file, 13 lines)
@@ -0,0 +1,13 @@
+eetq_commit := 323827dd471458a84e9c840f614e4592b157a4b1
+
+eetq:
+	# Clone eetq
+	pip install packaging
+	git clone https://github.com/NetEase-FuXi/EETQ.git eetq
+
+build-eetq: eetq
+	cd eetq && git fetch && git checkout $(eetq_commit)
+	cd eetq && python setup.py build
+
+install-eetq: build-eetq
+	cd eetq && python setup.py install
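Not part of the commit: a minimal, hypothetical post-install check that the extension built by `make install-eetq` exposes the kernels the server imports further down.

# Hypothetical sanity check, not part of this commit.
try:
    from EETQ import quant_weights, w8_a16_gemm  # noqa: F401
    print("EETQ kernels available")
except ImportError as err:
    print(f"EETQ is not installed: {err}")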
@@ -17,6 +17,7 @@ class Quantization(str, Enum):
     bitsandbytes_nf4 = "bitsandbytes-nf4"
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
+    eetq = "eetq"


 class Dtype(str, Enum):
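For illustration only, a standalone sketch (not the server module itself): because Quantization subclasses both str and Enum, the literal string "eetq" coming from the command line maps directly onto the new member.

from enum import Enum

# Standalone sketch of the pattern in the hunk above.
class Quantization(str, Enum):
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    eetq = "eetq"

assert Quantization("eetq") is Quantization.eetq
assert Quantization.eetq.value == "eetq"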
@@ -35,6 +35,13 @@ elif CAN_EXLLAMA:

 from typing import Optional

+HAS_EETQ = False
+try:
+    from EETQ import quant_weights, w8_a16_gemm
+    HAS_EETQ = True
+except ImportError:
+    pass
+
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -113,6 +120,30 @@ class FastLinear(nn.Module):
         return F.linear(input, self.weight, self.bias)


+class WeightOnlyLinear(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        device = weight.device
+        # Quantization is performed on a transposed, contiguous CPU copy of the weight.
+        weight = torch.t(weight).contiguous().cpu()
+        weight, scale = quant_weights(weight, torch.int8, False)
+        # The bias, when provided, stays unquantized and is only moved to the weight's device.
+        self.weight = weight.cuda(device)
+        self.scale = scale.cuda(device)
+        self.bias = bias.cuda(device) if bias is not None else None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = w8_a16_gemm(input, self.weight, self.scale)
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
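Not part of the commit: a minimal smoke test for the class above, assuming EETQ is installed and a CUDA GPU is available. It quantizes a random fp16 weight and compares the int8 weight-only GEMM against the unquantized reference.

import torch

# Hypothetical smoke test, not part of this commit; requires EETQ and a CUDA GPU.
out_features, in_features = 4096, 4096
weight = torch.randn(out_features, in_features, dtype=torch.float16, device="cuda")
x = torch.randn(8, in_features, dtype=torch.float16, device="cuda")

layer = WeightOnlyLinear(weight, bias=None)
y_int8 = layer(x)                                # w8_a16_gemm path
y_fp16 = torch.nn.functional.linear(x, weight)   # fp16 reference
print((y_int8 - y_fp16).abs().max())             # quantization error, expected to be small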
@@ -207,6 +238,11 @@ class Linear4bit(nn.Module):
 def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
+    elif quantize == "eetq":
+        if HAS_EETQ:
+            linear = WeightOnlyLinear(weight, bias)
+        else:
+            raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
     elif quantize == "bitsandbytes":
         linear = Linear8bitLt(
             weight,
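Not part of the commit: a short call-site sketch, assuming EETQ is installed and a CUDA device is available, showing how the new branch is reached when the server is started with --quantize eetq.

import torch

# Hypothetical illustration: dispatching through the new branch yields the
# EETQ int8 weight-only layer instead of FastLinear.
weight = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
layer = get_linear(weight, bias=None, quantize="eetq")
print(type(layer).__name__)  # WeightOnlyLinear if HAS_EETQ, otherwise ImportError is raised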
@@ -284,7 +320,7 @@ class TensorParallelHead(SuperLayer):
             should_gather = False

         # GPTQ doesn't quantize heads (nor embeddings)
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "eetq"]:
             quantize = None
         else:
             quantize = config.quantize
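Not part of the commit: as with GPTQ, the head and embeddings stay unquantized under eetq; forcing quantize to None sends the head through the plain FastLinear branch of get_linear, e.g.:

import torch

# Hypothetical illustration: with quantize forced to None for the head,
# get_linear takes the unquantized FastLinear branch.
head_weight = torch.randn(32000, 4096, dtype=torch.float16, device="cuda")
head = get_linear(head_weight, bias=None, quantize=None)  # FastLinear, not WeightOnlyLinear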