Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
add 4bit bnb quantization
parent 794767a98d
commit c9a78bbe0f
@@ -9,7 +9,7 @@ from typing import List
 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
-    from bitsandbytes.nn import Int8Params
+    from bitsandbytes.nn import Int8Params, Params4bit
 
 except ImportError:
     HAS_BITS_AND_BYTES = False
@@ -140,6 +140,39 @@ class Linear8bitLt(nn.Module):
         return out
 
 
+class Linear4bit(nn.Module):
+    def __init__(self, weight, bias, quant_type):
+        super().__init__()
+        self.weight = Params4bit(
+            weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
+        )
+        self.compute_dtype = None
+        self.weight.cuda(weight.device)
+        self.bias = bias
+
+    def forward(self, x: torch.Tensor):
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        if getattr(self.weight, "quant_state", None) is None:
+            print(
+                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
+            )
+        inp_dtype = x.dtype
+        if self.compute_dtype is not None:
+            x = x.to(self.compute_dtype)
+
+        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
+        out = bnb.matmul_4bit(
+            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
+        )
+
+        out = out.to(inp_dtype)
+
+        return out
+
+
 def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
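
Note (not part of the commit): a minimal sketch of how the new Linear4bit layer could be exercised. It assumes bitsandbytes >= 0.40 with a CUDA GPU available, and assumes the patched file is importable as text_generation_server.utils.layers; the weight layout (out_features, in_features) matches what the other linear layers in this file expect.

# Hypothetical usage sketch, not part of the diff.
# Assumes bitsandbytes >= 0.40, a CUDA device, and that Linear4bit lives in
# text_generation_server.utils.layers (the file patched above).
import torch

from text_generation_server.utils.layers import Linear4bit

# (out_features, in_features) weight and matching bias on the GPU.
weight = torch.randn(128, 64, dtype=torch.float16, device="cuda")
bias = torch.zeros(128, dtype=torch.float16, device="cuda")

layer = Linear4bit(weight, bias, quant_type="nf4")  # or quant_type="fp4"

x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
out = layer(x)  # dispatches to bnb.matmul_4bit in forward()
print(out.shape, out.dtype)  # torch.Size([4, 128]) torch.float16
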
@@ -152,6 +185,18 @@ def get_linear(weight, bias, quantize):
         )
         if bias is not None:
             linear.bias = nn.Parameter(bias)
+    elif quantize == "bitsandbytes-fp4":
+        linear = Linear4bit(
+            weight,
+            bias,
+            quant_type="fp4",
+        )
+    elif quantize == "bitsandbytes-nf4":
+        linear = Linear4bit(
+            weight,
+            bias,
+            quant_type="nf4",
+        )
     elif quantize == "gptq":
         try:
             qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
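
Note (not part of the commit): with the two new branches above, get_linear routes the "bitsandbytes-fp4" and "bitsandbytes-nf4" quantize values to Linear4bit. A minimal dispatch sketch, under the same module-path and CUDA assumptions as the previous note.

# Hypothetical dispatch sketch, not part of the diff.
# Assumes the patched get_linear is importable from text_generation_server.utils.layers.
import torch

from text_generation_server.utils.layers import get_linear

weight = torch.randn(128, 64, dtype=torch.float16, device="cuda")

linear_fp4 = get_linear(weight, bias=None, quantize="bitsandbytes-fp4")  # -> Linear4bit(quant_type="fp4")
linear_nf4 = get_linear(weight, bias=None, quantize="bitsandbytes-nf4")  # -> Linear4bit(quant_type="nf4")

x = torch.randn(2, 64, dtype=torch.float16, device="cuda")
print(linear_fp4(x).shape, linear_nf4(x).shape)  # torch.Size([2, 128]) for both
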