Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-26 12:32:10 +00:00
Mostly straightforward changes to existing code:
* Wrap quantizer parameters in a small wrapper to avoid passing around untyped tuples and needing to repack them as a dict.
* Move scratch space computation to warmup, because we need the maximum input sequence length to avoid allocating huge scratch buffers that OOM.
248 lines
7.4 KiB
Python
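A minimal sketch of the first change in the commit message above (illustrative only; the actual wrappers in the file below are the imported Exl2Weight/GPTQWeight dataclasses and the _ExtraTensors dataclass it defines). Instead of passing quantizer tensors around as an untyped tuple and repacking them into a dict, they are grouped in a small typed container:

    from dataclasses import dataclass
    from typing import Optional

    import torch

    # Before (illustrative): extra = (q_group_map, q_invperm, q_perm)
    # After: named, typed fields that can be passed as one object.
    @dataclass
    class ExtraTensors:
        q_group_map: Optional[torch.Tensor] = None
        q_invperm: Optional[torch.Tensor] = None
        q_perm: Optional[torch.Tensor] = None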
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from loguru import logger

from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight

try:
    from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError:
    logger.error("exllamav2_kernels not installed.")
    raise

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")


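# _ExtraTensors is the "small wrapper" referred to in the commit message: it groups
# the tensors generated while building the Q matrix instead of passing them around
# as an untyped tuple. It is stored on each QuantLinear (as `extra_tensors`), which
# also keeps Python-side references to these buffers alive, presumably for the same
# reason as the temp_dq note in post_init below.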
@dataclass
class _ExtraTensors:
    """Additional generated quantizer tensors."""

    q_group_map: Optional[torch.Tensor] = None
    q_invperm: Optional[torch.Tensor] = None
    q_perm: Optional[torch.Tensor] = None


def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
    """Matrix multiplication, returns x @ q4"""
    # Flatten all leading dimensions so the kernel sees a 2D matmul, then restore
    # the original leading dimensions on the output.
    output_shape = x.shape[:-1] + (q4_width,)
    x = x.view(-1, x.shape[-1])
    output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)
    gemm_half_q_half(x, q_handle, output, force_cuda)
    return output.view(output_shape)


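# Inferred from the indexing below: q_groups is a flat tensor of (bits, first_q_row)
# pairs, one pair per quantization group. make_group_map expands it into a per-row
# map of (group index, rows remaining in the group, counting the current row).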
def make_group_map(q_groups: torch.Tensor, num_qrows: int):
    gr = q_groups.tolist()
    group_map = []
    num_groups = len(gr) // 2

    for i in range(num_groups):
        bits = gr[i * 2]
        if i < num_groups - 1:
            qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
        else:
            qrows = num_qrows - gr[i * 2 + 1]
        rows = qrows * 32 // bits
        for j in range(rows):
            group_map += [i]
            group_map += [rows - j]

    return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)


# Create Q matrix


def ext_make_q_matrix(
    w: Exl2Weight | GPTQWeight,
    extra: _ExtraTensors,
    temp_dq,
    key: Optional[str] = None,
):
    """
    Create Q matrix
    """
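    # make_q_matrix takes a fixed set of tensor slots; slots that do not apply to a
    # given format are filled with none_tensor. As used below, the EXL2 branch fills
    # the q_scale / q_scale_max / q_groups / q_group_map slots, while the GPTQ
    # branches fill the qzeros / scales / g_idx slots instead.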
    # EXL2
    if isinstance(w, Exl2Weight):
        extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
        extra.q_perm = torch.argsort(w.q_invperm).short()

        return make_q_matrix(
            w.q_weight,
            extra.q_perm,
            w.q_invperm,
            w.q_scale,
            w.q_scale_max,
            w.q_groups,
            extra.q_group_map,
            none_tensor,
            none_tensor,
            none_tensor,
            temp_dq,
        )
    # GPTQ
    elif isinstance(w, GPTQWeight):
        if w.scales.dtype == torch.float:
            w.scales = w.scales.half()

        # GPTQ with g_idx (act_order)
        if w.g_idx is not None and not (w.g_idx == 0).all().item():
            # With 4-bit packing, each 32-bit qweight row holds 8 values, hence * 8.
            extra.q_perm = torch.empty(
                (w.qweight.shape[0] * 8,),
                dtype=torch.short,
                device=w.qweight.device,
            )
            extra.q_invperm = torch.empty_like(extra.q_perm)
            # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the
            # non act-order case, None needs to be passed for g_idx.
            return make_q_matrix(
                w.qweight,
                extra.q_perm,
                extra.q_invperm,
                none_tensor,
                none_tensor,
                none_tensor,
                none_tensor,
                w.qzeros,
                w.scales,
                w.g_idx.cpu(),
                temp_dq,
            )
        # GPTQ without g_idx
        else:
            return make_q_matrix(
                w.qweight,
                none_tensor,
                none_tensor,
                none_tensor,
                none_tensor,
                none_tensor,
                none_tensor,
                w.qzeros,
                w.scales,
                none_tensor,
                temp_dq,
            )
    else:
        raise RuntimeError("Cannot create handle")


DEVICE = None
LAYERS = []


def set_device(device):
    global DEVICE
    DEVICE = device


def create_exllama_buffers(max_total_tokens: int):
    global LAYERS, DEVICE

    # Find the size of the scratch space.
    scratch_bytes = max(
        layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS
    )
    temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)

    for layer in LAYERS:
        layer.post_init(temp_dq)


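# Rough sketch of the intended warmup flow (the actual call sites live elsewhere in
# the server, not in this file):
#
#   set_device(torch.device("cuda:0"))
#   # ... build the model; each QuantLinear appends itself to LAYERS ...
#   create_exllama_buffers(max_total_tokens=4096)  # illustrative value
#
# Sizing the scratch space from the server's real maximum total tokens (the second
# point of the commit message) avoids allocating a huge fixed-size buffer that
# could OOM.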
class QuantLinear(nn.Module):
    """Linear layer implementation with per-group 4-bit quantization of the weights"""

    QUANT_TYPE = "exllamav2"

    def __init__(
        self,
        weight: Exl2Weight | GPTQWeight,
        bias: torch.Tensor,
    ):
        super().__init__()

        self.q_handle = None
        self.q_tensors = weight
        self.extra_tensors = _ExtraTensors()

        if isinstance(weight, Exl2Weight):
            self.infeatures = weight.q_invperm.shape[0]
            self.outfeatures = weight.q_weight.shape[1]
        elif isinstance(weight, GPTQWeight):
            if weight.bits != 4:
                raise ValueError(
                    f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
                )

            # qweight packs 32 // bits input features per 32-bit row along dim 0.
            self.infeatures = weight.qweight.shape[0] // weight.bits * 32
            self.outfeatures = weight.qweight.shape[1]

        self.padding = -self.outfeatures % 32
        self.outfeatures = self.outfeatures + self.padding

        self.device = weight.device
        self.bias = bias if bias is not None else None

        global LAYERS
        LAYERS.append(self)

    def post_init(self, temp_dq):
        device = self.q_tensors.device
        assert device.type == "cuda"
        assert device.index is not None
        temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())

        # We NEED to keep a pointer on the Python side, otherwise the garbage collector
        # will mess with us, and `Memory access fault by GPU node-2` will EAT you.
        self.temp_dq = temp_dq
        self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)

    def forward(self, x, force_cuda=False):
        output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)

        if self.bias is not None:
            output.add_(self.bias)
        return output

    def temp_dq_size(self):
        return self.infeatures * self.outfeatures * 2 + 128

    def temp_fwd_size(self, max_input_len, max_batch_size):
        return self.outfeatures * max_input_len * max_batch_size * 4 + 128

    def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16):
        return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)


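# Worked example of the scratch sizing above (illustrative numbers): for a layer with
# infeatures = outfeatures = 4096, temp_dq_size() is 4096 * 4096 * 2 + 128 bytes
# (about 32 MiB), and temp_fwd_size(4096, 16) is 4096 * 4096 * 16 * 4 + 128 bytes
# (about 1 GiB), so scratch_space_fixed() is roughly 1.03 GiB for that layer.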
class ExLlamaV2DeviceTensors:

    device_idx: int
    scratch_bytes: int
    scratch_idx: int
    scratch: Optional[torch.Tensor] = None

    def __init__(self, device, scratch_bytes):
        self.device = device
        self.scratch_bytes = scratch_bytes

    def prepare(self):
        self.scratch = torch.empty(
            (self.scratch_bytes // 2,), dtype=torch.half, device=self.device
        )

    def get_scratch_slice(self, size_bytes):
        # Allocate the scratch buffer lazily on first use.
        if self.scratch is None:
            self.prepare()

        # Round up to a multiple of 128 bytes, then slice in half-precision elements.
        size_bytes = ((size_bytes + 127) // 128) * 128
        size_half = size_bytes // 2
        scratch_slice = self.scratch.narrow(0, 0, size_half)
        return scratch_slice