mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
Add support for GPTQ Marlin kernels. GPTQ Marlin extends the Marlin kernels to support common GPTQ configurations — bits: 4 or 8; groupsize: -1, 32, 64, or 128; desc_act: true/false. Using the GPTQ Marlin kernels requires repacking the parameters in the Marlin quantizer format. The kernels were contributed by Neural Magic to vLLM. We vendor them here for convenience.
247 lines
8.1 KiB
Python
247 lines
8.1 KiB
Python
from typing import Optional
|
|
import torch
|
|
from torch.nn import functional as F
|
|
from text_generation_server.layers.marlin import GPTQMarlinLinear
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
|
|
if SYSTEM == "rocm":
    # On ROCm, `FastLinearROCm.forward` calls `_custom_C.LLMM1`, so the
    # extension must be importable; fail at import time if it is not.
    try:
        from vllm import _custom_C
    except Exception as e:
        # Chain the original exception so its traceback is kept as the
        # explicit cause instead of only being flattened into the message.
        raise ImportError(
            f"Could not load `vllm._custom_C`. Full error: {e}"
        ) from e
|
|
|
|
|
|
class FastLinear(torch.nn.Module):
    """Plain (unquantized) linear layer evaluated with ``F.linear``.

    The weight and the optional bias are stored as parameters with
    gradients disabled.
    """

    def __init__(self, weight, bias) -> None:
        """Wrap ``weight`` (and ``bias`` when given) as frozen parameters."""
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        # A missing bias is kept as a plain ``None`` attribute.
        self.bias = (
            None if bias is None else torch.nn.Parameter(bias, requires_grad=False)
        )

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Build the layer from tensors stored under ``prefix``.

        ``weights`` must provide ``get_tensor(name)``. ``{prefix}.weight`` is
        always read; ``{prefix}.bias`` is read only when ``bias`` is truthy.
        ``config`` is accepted but not used here.
        """
        weight_tensor = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight_tensor, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the linear transform to ``input``."""
        return F.linear(input, self.weight, self.bias)
|
|
|
|
|
|
class FastLinearROCm(torch.nn.Module):
    """Unquantized linear layer with a ROCm fast path for single-row inputs.

    When ``SYSTEM == "rocm"`` and the input holds exactly one row
    (``inp.numel() // inp.shape[-1] == 1``), the product is computed with
    ``_custom_C.LLMM1`` for the supported weight shapes. Every other case
    falls back to ``F.linear``.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        """Wrap ``weight`` (and ``bias`` when given) as parameters."""
        super().__init__()
        self.weight = torch.nn.Parameter(weight)
        if bias is not None:
            self.bias = torch.nn.Parameter(bias)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Build the layer from tensors stored under ``prefix``.

        ``weights`` must provide ``get_tensor(name)``. ``{prefix}.weight`` is
        always read; ``{prefix}.bias`` is read only when ``bias`` is truthy.
        ``config`` is accepted but not used here.
        """
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Apply the linear transform to ``inp``.

        Args:
            inp: Input tensor whose last dimension is the input feature size.

        Returns:
            Tensor with the same leading dimensions as ``inp`` and
            ``weight.shape[0]`` as the last dimension.
        """
        weight = self.weight
        bias = self.bias

        if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
            inp_shape = inp.shape

            # The kernel path works on a 2D (rows, features) input; flatten
            # any other rank and restore the leading dimensions afterwards.
            needs_reshape = inp.dim() != 2
            if needs_reshape:
                inp = inp.view(-1, inp_shape[-1])

            # Read `k` from the last dimension: for a 3D input, index 1 of
            # the original shape is not the feature dimension.
            m, k = weight.shape[0], inp.shape[-1]
            out = torch.empty(
                inp.shape[0], m, dtype=inp.dtype, device=inp.device
            )
            if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
                _custom_C.LLMM1(weight, inp, out, 8)
            elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
                _custom_C.LLMM1(weight, inp, out, 4)
            else:
                out = F.linear(inp, weight)

            if needs_reshape:
                # `Tensor.view` returns a new tensor, so the result must be
                # assigned for the reshape to take effect.
                out = out.view(*inp_shape[:-1], out.shape[-1])

            if bias is not None:
                out = out + bias
            return out
        return F.linear(inp, self.weight, self.bias)
|
|
|
|
|
|
def get_linear(weight, bias, quantize):
    """Build the linear layer that matches the requested quantization.

    Args:
        weight: Weight handed to the selected layer class. The ``exl2``,
            ``gptq``/``awq`` and ``marlin`` schemes require an ``Exl2Weight``,
            a ``GPTQWeight`` and a ``MarlinWeight``/``GPTQMarlinWeight``
            respectively.
        bias: Optional bias passed to the layer, or ``None``.
        quantize: ``None`` for an unquantized layer, otherwise one of
            ``"eetq"``, ``"fp8"``, ``"bitsandbytes"``, ``"bitsandbytes-fp4"``,
            ``"bitsandbytes-nf4"``, ``"exl2"``, ``"gptq"``, ``"awq"`` or
            ``"marlin"``.

    Returns:
        The constructed linear module.

    Raises:
        ImportError: The EETQ layer cannot be imported.
        NotImplementedError: The scheme is unknown, an optional backend is
            not installed, ``weight`` has the wrong type for the scheme, or
            AWQ is requested on a ROCm system.
    """
    if quantize is None:
        if SYSTEM == "rocm":
            linear = FastLinearROCm(weight, bias)
        else:
            linear = FastLinear(weight, bias)
    elif quantize == "eetq":
        # Only the import is guarded, so errors from building the layer are
        # not reported as a missing installation.
        try:
            from text_generation_server.layers.eetq import EETQLinear
        except ImportError as e:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            ) from e

        linear = EETQLinear(weight, bias)
    elif quantize == "fp8":
        from text_generation_server.layers.fp8 import Fp8Linear

        linear = Fp8Linear(weight, bias)
    elif quantize == "bitsandbytes":
        try:
            from text_generation_server.layers.bnb import (
                warn_deprecate_bnb,
                Linear8bitLt,
            )
        except ImportError as e:
            raise NotImplementedError(
                "Bitsandbytes is missing install it with `pip install bitsandbytes`."
            ) from e
        warn_deprecate_bnb()
        linear = Linear8bitLt(
            weight,
            bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        if bias is not None:
            # Use the fully qualified name: this module imports `torch` and
            # `F`, but never binds a bare `nn`.
            linear.bias = torch.nn.Parameter(bias)
    elif quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        try:
            from text_generation_server.layers.bnb import Linear4bit
        except ImportError as e:
            raise NotImplementedError(
                "Bitsandbytes is missing install it with `pip install bitsandbytes`."
            ) from e
        linear = Linear4bit(
            weight,
            bias,
            quant_type="fp4" if quantize == "bitsandbytes-fp4" else "nf4",
        )
    elif quantize == "exl2":
        from text_generation_server.layers.exl2 import Exl2Weight

        if not isinstance(weight, Exl2Weight):
            raise NotImplementedError(
                "The passed weight is not `exl2` compatible, loader needs to be updated."
            )

        from text_generation_server.layers.gptq import ExllamaQuantLinear

        linear = ExllamaQuantLinear(weight, bias)

    elif quantize == "gptq":
        from text_generation_server.layers.gptq import GPTQWeight

        if not isinstance(weight, GPTQWeight):
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            )

        if weight.use_exllama:
            try:
                from text_generation_server.layers.gptq import (
                    ExllamaQuantLinear,
                )
            except ImportError as e:
                raise NotImplementedError(
                    "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
                ) from e

            linear = ExllamaQuantLinear(weight, bias)
        else:
            from text_generation_server.layers.gptq.quant_linear import QuantLinear

            linear = QuantLinear(
                weight.qweight,
                weight.qzeros,
                weight.scales,
                weight.g_idx,
                bias,
                weight.bits,
                weight.groupsize,
            )
    elif quantize == "awq":
        # AWQ weights are carried in the same GPTQWeight container.
        from text_generation_server.layers.gptq import GPTQWeight

        if not isinstance(weight, GPTQWeight):
            raise NotImplementedError(
                "The passed weight is not `awq` compatible, loader needs to be updated."
            )
        if SYSTEM == "rocm":
            raise NotImplementedError(
                "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
                "to use Exllama/GPTQ kernels for AWQ inference."
            )
        try:
            from text_generation_server.layers.awq.quantize.qmodule import WQLinear
        except ImportError as e:
            raise NotImplementedError(
                "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `--quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
            ) from e

        linear = WQLinear(
            w_bit=weight.bits,
            group_size=weight.groupsize,
            qweight=weight.qweight,
            qzeros=weight.qzeros,
            scales=weight.scales,
            # WQLinear is given a flag here, not the bias tensor itself.
            bias=bias is not None,
        )
    elif quantize == "marlin":
        from text_generation_server.layers.marlin import (
            GPTQMarlinLinear,
            GPTQMarlinWeight,
            MarlinLinear,
            MarlinWeight,
        )

        if isinstance(weight, GPTQMarlinWeight):
            linear = GPTQMarlinLinear(
                weight=weight,
                bias=bias,
            )
        elif isinstance(weight, MarlinWeight):
            linear = MarlinLinear(weight=weight, bias=bias)
        else:
            raise NotImplementedError(
                "The passed weight is not `marlin` compatible, loader needs to be updated."
            )
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear
|