mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 22:32:07 +00:00
90 lines
3.0 KiB
Python
import os

import torch
from loguru import logger

from custom_kernels.exllama import make_q4, q4_matmul, set_tuning_params, prepare_buffers


def ext_q4_matmul(x, q4, q4_width):
    """Matrix multiplication, returns x @ q4"""
    outshape = x.shape[:-1] + (q4_width,)
    x = x.view(-1, x.shape[-1])
    output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)

    q4_matmul(x, q4, output)

    return output.view(outshape)
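
# Illustrative call (added; shapes are hypothetical, q4 comes from make_q4
# below): a half-precision activation of shape (batch, seq, hidden) is
# flattened to 2D, multiplied, and reshaped back:
#
#   x = torch.randn(2, 16, 4096, dtype=torch.float16, device="cuda:0")
#   y = ext_q4_matmul(x, q4, q4_width=4096)  # y.shape == (2, 16, 4096)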


RANK = os.getenv("RANK", "0")
DEVICE = torch.device(f"cuda:{RANK}")

MAX_TOTAL_TOKENS = 1
MAX_INNER_OUTER_DIM = 0
MAX_DQ_BUFFER_SIZE = 0
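
# Added note, inferred from the code below: each Ex4bitLinear constructed grows
# MAX_INNER_OUTER_DIM and MAX_DQ_BUFFER_SIZE to fit the largest layer seen, so
# create_buffers() is only meaningful once every layer has been built.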


def create_buffers():
    temp_state = torch.zeros(
        (MAX_TOTAL_TOKENS, MAX_INNER_OUTER_DIM), dtype=torch.float16, device=DEVICE
    )
    temp_dq = torch.zeros((1, MAX_DQ_BUFFER_SIZE), dtype=torch.float16, device=DEVICE)
    logger.info(f"Creating buffers {temp_state.shape} - {temp_dq.shape} - {DEVICE}")

    prepare_buffers(DEVICE, temp_state, temp_dq)

    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
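
# Added note on the tuning parameters above (hedged, based on the upstream
# exllama kernel these names come from): matmul_recons_thd is the row-count
# threshold above which the kernel reconstructs a dense fp16 matrix into
# temp_dq instead of multiplying against the packed weights directly; the two
# flags disable the fused column remap and the half2 vectorized paths.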


class Ex4bitLinear:
    """Linear layer implementation with per-group 4-bit quantization of the weights"""

    def __init__(self, qweight, qzeros, scales, bias, bits):
        assert bits == 4, "We cannot run exllama GPTQ kernels if bits != 4"

        global MAX_INNER_OUTER_DIM, MAX_DQ_BUFFER_SIZE
        # qweight packs eight 4-bit values into each int32, so the dequantized
        # fp16 matrix has 8x the elements and 8x the rows of the packed tensor.
        dq = qweight.numel() * 8
        if dq > MAX_DQ_BUFFER_SIZE:
            MAX_DQ_BUFFER_SIZE = dq

        width = qweight.shape[1]
        if width > MAX_INNER_OUTER_DIM:
            MAX_INNER_OUTER_DIM = width
        height = qweight.shape[0] * 8
        if height > MAX_INNER_OUTER_DIM:
            MAX_INNER_OUTER_DIM = height

        # prepare_buffers(DEVICE, TEMP_STATE, TEMP_DQ)

        self.q4 = make_q4(
            qweight,
            qzeros,
            scales,
            # Dummy tensor passed instead of g_idx, since there is no way to
            # pass "None" to a C++ extension. Never send a real g_idx: it MUST
            # be like act_order=False, the exllama kernel does not expect it.
            torch.zeros((0, 0), device=torch.device("meta")),
            DEVICE.index,
        )
        self.bias = bias
        self.width = width

        # # Infer groupsize from height of qzeros
        # self.groupsize = None
        # if self.qzeros.shape[0] > 1:
        #     self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])

        # if self.groupsize is not None:
        #     assert groupsize == self.groupsize

        # # Handle act-order matrix
        # if self.g_idx is not None:
        #     if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
        #     self.act_order = True
        # else:
        #     self.act_order = False
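
        # Worked example of the groupsize inference above (added commentary,
        # hypothetical numbers): a packed qweight of shape (512, 4096) unpacks
        # to 512 * 8 = 4096 rows, so a qzeros with 32 rows would imply
        # groupsize = 4096 // 32 = 128.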

    def forward(self, x):
        out = ext_q4_matmul(x, self.q4, self.width)

        if self.bias is not None:
            out.add_(self.bias)
        return out
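

# Minimal usage sketch (added; tensors are hypothetical and assume the
# custom_kernels extension is built and cuda:RANK is available). Construct all
# layers first so the MAX_* globals reach their final values, then create the
# buffers once before any forward pass:
#
#   layer = Ex4bitLinear(qweight, qzeros, scales, bias=None, bits=4)
#   create_buffers()
#   x = torch.randn(1, 4, qweight.shape[0] * 8, dtype=torch.float16, device=DEVICE)
#   out = layer.forward(x)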