diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index d7b4c0cc..97257f95 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -9,7 +9,7 @@ from typing import List
 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
-    from bitsandbytes.nn import Int8Params
+    from bitsandbytes.nn import Int8Params, Params4bit
 except ImportError:
     HAS_BITS_AND_BYTES = False
 
@@ -140,6 +140,39 @@ class Linear8bitLt(nn.Module):
         return out
 
 
+class Linear4bit(nn.Module):
+    def __init__(self, weight, bias, quant_type):
+        super().__init__()
+        self.weight = Params4bit(
+            weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
+        )
+        self.compute_dtype = None
+        self.weight.cuda(weight.device)
+        self.bias = bias
+
+    def forward(self, x: torch.Tensor):
+        # weights are cast automatically as Params4bit, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        if getattr(self.weight, "quant_state", None) is None:
+            print(
+                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
+            )
+        inp_dtype = x.dtype
+        if self.compute_dtype is not None:
+            x = x.to(self.compute_dtype)
+
+        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
+        out = bnb.matmul_4bit(
+            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
+        )
+
+        out = out.to(inp_dtype)
+
+        return out
+
+
 def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
@@ -152,6 +185,18 @@ def get_linear(weight, bias, quantize):
         )
         if bias is not None:
             linear.bias = nn.Parameter(bias)
+    elif quantize == "bitsandbytes-fp4":
+        linear = Linear4bit(
+            weight,
+            bias,
+            quant_type="fp4",
+        )
+    elif quantize == "bitsandbytes-nf4":
+        linear = Linear4bit(
+            weight,
+            bias,
+            quant_type="nf4",
+        )
     elif quantize == "gptq":
         try:
             qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
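
Not part of the diff: a minimal usage sketch of the new code path, assuming the server package is installed so that `text_generation_server.utils.layers` is importable, a CUDA device is available, and the installed bitsandbytes release ships `Params4bit` and `matmul_4bit` (roughly >= 0.40). Tensor shapes are illustrative only.

```python
# Sketch: exercising the new 4-bit branch of get_linear() added by this diff.
import torch

from text_generation_server.utils.layers import get_linear

# fp16 weight/bias as the server would load them; (out_features, in_features).
weight = torch.randn(768, 768, dtype=torch.float16, device="cuda")
bias = torch.zeros(768, dtype=torch.float16, device="cuda")

# "bitsandbytes-nf4" routes to Linear4bit(quant_type="nf4");
# "bitsandbytes-fp4" selects the FP4 data type instead.
linear = get_linear(weight, bias, quantize="bitsandbytes-nf4")

x = torch.randn(4, 768, dtype=torch.float16, device="cuda")
out = linear(x)  # forward runs bnb.matmul_4bit against the packed 4-bit weight
print(out.shape, out.dtype)  # torch.Size([4, 768]) torch.float16
```

The two new `quantize` values only change the 4-bit data type handed to `Params4bit`; quantization itself happens when `__init__` calls `.cuda()` on the parameter, and the forward path through `bnb.matmul_4bit` is identical for both.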