Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 14:52:20 +00:00)
Mostly straightforward. Changes to existing code:

* Wrap quantizer parameters in a small wrapper to avoid passing around untyped tuples and needing to repack them as a dict.
* Move scratch-space computation to warmup, because we need the maximum input sequence length to avoid allocating huge scratch buffers that OOM.
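The "small wrapper" referred to above is the GPTQWeight class imported at the top of the file below from text_generation_server.utils.weights. As a rough, non-authoritative sketch, the fields this module relies on could be grouped as follows (the real upstream definition may carry additional fields):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class GPTQWeight:
    qweight: torch.Tensor  # packed quantized weights (int32, eight 4-bit values per element)
    qzeros: torch.Tensor  # packed per-group zero points
    scales: torch.Tensor  # per-group scales
    g_idx: Optional[torch.Tensor]  # group index; None when there is no act-order
    bits: int  # quantization bit width (must be 4 for the exllama kernels)
    groupsize: int  # quantization group size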
135 lines
4.0 KiB
Python
from text_generation_server.utils.weights import GPTQWeight

import torch

from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")


def ext_make_q4(qweight, qzeros, scales, g_idx, device):
    """Construct Q4Matrix, return handle"""
    return make_q4(
        qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device
    )


def ext_q4_matmul(x, q4, q4_width):
    """Matrix multiplication, returns x @ q4"""
    outshape = x.shape[:-1] + (q4_width,)
    x = x.view(-1, x.shape[-1])
    output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)

    q4_matmul(x, q4, output)

    return output.view(outshape)

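# Scratch-buffer bookkeeping: each Ex4bitLinear below grows MAX_DQ / MAX_INNER and
# sets ACT_ORDER as it is constructed; create_exllama_buffers() then sizes the
# shared temporary buffers once, during warmup.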
MAX_DQ = 1
MAX_INNER = 1
ACT_ORDER = False
DEVICE = None

TEMP_STATE = None
TEMP_DQ = None


def set_device(device):
    global DEVICE
    DEVICE = device

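# Called once during warmup (see the commit description above): max_total_tokens bounds
# the reorder scratch buffer so huge worst-case allocations (and OOMs) are avoided.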
def create_exllama_buffers(max_total_tokens: int):
    global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ

    assert DEVICE is not None, "call set_device first"

    if not ACT_ORDER:
        max_total_tokens = 1

    # This temp_state buffer is required to reorder X in the act-order case.
    temp_state = torch.zeros(
        (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
    )
    # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
    temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)

    prepare_buffers(DEVICE, temp_state, temp_dq)

    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

    TEMP_STATE, TEMP_DQ = temp_state, temp_dq

class Ex4bitLinear(torch.nn.Module):
    """Linear layer implementation with per-group 4-bit quantization of the weights"""

    def __init__(self, weight: GPTQWeight, bias):
        super().__init__()
        global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
        assert weight.bits == 4

        self.device = weight.qweight.device
        self.qweight = weight.qweight
        self.qzeros = weight.qzeros
        self.scales = weight.scales
        self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None
        self.bias = bias if bias is not None else None

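        # A group index that is all zeros, or exactly [i // groupsize], carries no
        # act-order information, so it can be dropped and reordering skipped.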
        if self.g_idx is not None and (
            (self.g_idx == 0).all()
            or torch.equal(
                weight.g_idx.cpu(),
                torch.tensor(
                    [i // weight.groupsize for i in range(weight.g_idx.shape[0])],
                    dtype=torch.int32,
                ),
            )
        ):
            self.empty_g_idx = True
            self.g_idx = None

        assert self.device.type == "cuda"
        assert self.device.index is not None

        self.q4 = ext_make_q4(
            self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
        )

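        # qweight packs eight 4-bit values per int32 along dim 0, hence the factor 8.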
        self.height = weight.qweight.shape[0] * 8
        self.width = weight.qweight.shape[1]

        # Infer groupsize from height of qzeros
        self.groupsize = None
        if self.qzeros.shape[0] > 1:
            self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])

        if self.groupsize is not None:
            assert weight.groupsize == self.groupsize

        # Handle act-order matrix
        if self.g_idx is not None:
            if self.groupsize is None:
                raise ValueError("Found group index but no groupsize. What do?")
            self.act_order = True
        else:
            self.act_order = False

        DEVICE = self.qweight.device

        MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)

        if self.act_order:
            MAX_INNER = max(MAX_INNER, self.height, self.width)

            ACT_ORDER = True

    def forward(self, x):
        out = ext_q4_matmul(x, self.q4, self.width)

        if self.bias is not None:
            out.add_(self.bias)
        return out
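For context, a hypothetical usage sketch (not part of the file above), assuming a CUDA device and a GPTQWeight already loaded elsewhere; in text-generation-inference itself this wiring is done by the model-loading and warmup code paths:

import torch

# Names below refer to the definitions in the file above; loading `weight` as a
# GPTQWeight is assumed to have happened elsewhere.
device = torch.device("cuda", 0)
set_device(device)

# Construct the quantized layers first so MAX_DQ / MAX_INNER / ACT_ORDER accumulate...
layer = Ex4bitLinear(weight, bias=None)

# ...then size the shared scratch buffers once, during warmup.
create_exllama_buffers(max_total_tokens=4096)

x = torch.randn(4, layer.height, dtype=torch.float16, device=device)
y = layer.forward(x)  # shape (4, layer.width)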