text-generation-inference/server/text_generation_server/utils/gptq/exllama.py

import os

import torch
from custom_kernels.exllama import make_q4, q4_matmul, set_tuning_params, prepare_buffers
from loguru import logger

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.zeros((0, 0), device=torch.device("meta"))

def ext_q4_matmul(x, q4, q4_width):
    """Matrix multiplication, returns x @ q4"""
    outshape = x.shape[:-1] + (q4_width,)
    x = x.view(-1, x.shape[-1])
    output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)

    q4_matmul(x, q4, output)

    return output.view(outshape)

RANK = os.getenv("RANK", "0")
DEVICE = torch.device(f"cuda:{RANK}")

MAX_TOTAL_TOKENS = 1
MAX_INNER_OUTER_DIM = 0
MAX_DQ_BUFFER_SIZE = 0
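
# These maxima are grown by every Ex4bitLinear constructed below; create_buffers()
# sizes its scratch tensors from them, so it is intended to run only after all
# quantized layers have been built.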

def create_buffers():
    temp_state = torch.zeros(
        (MAX_TOTAL_TOKENS, MAX_INNER_OUTER_DIM), dtype=torch.float16, device=DEVICE
    )
    temp_dq = torch.zeros((1, MAX_DQ_BUFFER_SIZE), dtype=torch.float16, device=DEVICE)

    logger.info(f"Creating buffers {temp_state.shape} - {temp_dq.shape} - {DEVICE}")
    prepare_buffers(DEVICE, temp_state, temp_dq)

    # Kernel tuning: matmul_recons_thd is the row-count threshold above which the
    # exllama kernel reconstructs fp16 weights and uses a regular matmul instead
    # of the fused 4-bit kernel.
    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

class Ex4bitLinear:
    """Linear layer implementation with per-group 4-bit quantization of the weights"""

    def __init__(self, qweight, qzeros, scales, bias, bits):
        assert bits == 4, "We cannot run exllama GPTQ kernels if bits != 4"

        global MAX_INNER_OUTER_DIM, MAX_DQ_BUFFER_SIZE
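
        # qweight packs eight 4-bit values per int32, so dequantizing it needs
        # qweight.numel() * 8 fp16 elements and the logical height of the matrix
        # is qweight.shape[0] * 8 rows; track the largest sizes seen so far.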
        dq = qweight.numel() * 8
        if dq > MAX_DQ_BUFFER_SIZE:
            MAX_DQ_BUFFER_SIZE = dq

        width = qweight.shape[1]
        if width > MAX_INNER_OUTER_DIM:
            MAX_INNER_OUTER_DIM = width
        height = qweight.shape[0] * 8
        if height > MAX_INNER_OUTER_DIM:
            MAX_INNER_OUTER_DIM = height
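
        # The scratch buffers themselves are allocated later by create_buffers(),
        # once the maxima above cover every layer, which is presumably why the
        # prepare_buffers call below is left commented out.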
        # prepare_buffers(DEVICE, TEMP_STATE, TEMP_DQ)

        self.q4 = make_q4(
            qweight,
            qzeros,
            scales,
            # Never send g_idx, it MUST be like act_order=False, the exllama kernel does not expect it
            none_tensor,
            DEVICE.index,
        )

        self.bias = bias if bias is not None else None
        self.width = width

        # # Infer groupsize from height of qzeros
        # self.groupsize = None
        # if self.qzeros.shape[0] > 1:
        #     self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
        # if self.groupsize is not None:
        #     assert groupsize == self.groupsize

        # # Handle act-order matrix
        # if self.g_idx is not None:
        #     if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
        #     self.act_order = True
        # else:
        #     self.act_order = False

    def forward(self, x):
        out = ext_q4_matmul(x, self.q4, self.width)

        if self.bias is not None:
            out.add_(self.bias)

        return out
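

# A minimal usage sketch (for illustration only, not part of this module):
# construct every quantized layer first, then size the shared scratch buffers,
# then run forward passes.
#
#     layer = Ex4bitLinear(qweight, qzeros, scales, bias=None, bits=4)
#     create_buffers()
#     y = layer.forward(x)  # x: float16, shape (..., in_features)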