Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 00:12:08 +00:00)
Mostly straightforward; changes to existing code:

* Wrap quantizer parameters in a small wrapper to avoid passing around untyped tuples and needing to repack them as a dict.
* Move scratch space computation to warmup, because we need the maximum input sequence length to avoid allocating huge scratch buffers that OOM.
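To make the first change concrete: the idea is to replace an untyped dict (or tuple) of quantizer tensors with a small typed wrapper, so call sites use attribute access instead of string keys. The sketch below is illustrative only; QuantizerParams and describe() are hypothetical names, while the actual patch uses the GPTQWeight and Exl2Weight classes plus the _ExtraTensors holder shown in the diff.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class QuantizerParams:
    # Hypothetical wrapper; the real patch uses GPTQWeight / Exl2Weight.
    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: Optional[torch.Tensor]
    bits: int
    groupsize: int


def describe(w: QuantizerParams) -> str:
    # Attribute access replaces dict lookups such as w["qweight"], so typos
    # fail loudly and type checkers can verify call sites.
    return f"{w.bits}-bit quantized weight of shape {tuple(w.qweight.shape)}"


if __name__ == "__main__":
    w = QuantizerParams(
        qweight=torch.zeros(128, 256, dtype=torch.int32),
        qzeros=torch.zeros(4, 32, dtype=torch.int32),
        scales=torch.ones(4, 256, dtype=torch.float16),
        g_idx=None,
        bits=4,
        groupsize=128,
    )
    print(describe(w))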
254 lines
8.4 KiB
Plaintext
diff a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py (rejected hunks)
@@ -1,10 +1,15 @@
 # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
 
+from dataclasses import dataclass
+from typing import Optional
 import torch
 import torch.nn as nn
 
 from loguru import logger
 
+from text_generation_server.layers.exl2 import Exl2Weight
+from text_generation_server.layers.gptq import GPTQWeight
+
 try:
     from exllamav2_kernels import make_q_matrix, gemm_half_q_half
 except ImportError:
@@ -15,6 +20,15 @@ except ImportError:
 none_tensor = torch.empty((1, 1), device="meta")
 
 
+@dataclass
+class _ExtraTensors:
+    """Additional generated quantizer tensors."""
+
+    q_group_map: Optional[torch.Tensor] = None
+    q_invperm: Optional[torch.Tensor] = None
+    q_perm: Optional[torch.Tensor] = None
+
+
 def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
     """Matrix multiplication, returns x @ q4"""
     output_shape = x.shape[:-1] + (q4_width,)
@@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
     return output.view(output_shape)
 
 
-# Group map needed for irregular group sizes
-
-
-def make_group_map(q_groups, num_qrows):
-
+def make_group_map(q_groups: torch.Tensor, num_qrows: int):
     gr = q_groups.tolist()
     group_map = []
     num_groups = len(gr) // 2
@@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows):
 # Create Q matrix
 
 
-def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
+def ext_make_q_matrix(
+    w: Exl2Weight | GPTQWeight,
+    extra: _ExtraTensors,
+    temp_dq,
+    key: Optional[str] = None,
+):
     """
     Create Q matrix
     """
     # EXL2
-    # won't work as the moment because the tensors are not the same.
-    if "q_weight" in w:
-        w["q_scale_max"] /= 256
-        w["q_perm"] = w["q_perm"].short()
-        w["q_invperm"] = w["q_invperm"].short()
-
-        if "q_group_map" not in w:
-            w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
+    if isinstance(w, Exl2Weight):
+        extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
+        extra.q_perm = torch.argsort(w.q_invperm).short()
 
         return make_q_matrix(
-            w["q_weight"],
-            w["q_perm"],
-            w["q_invperm"],
-            w["q_scale"],
-            w["q_scale_max"],
-            w["q_groups"],
-            w["q_group_map"],
+            w.q_weight,
+            extra.q_perm,
+            w.q_invperm,
+            w.q_scale,
+            w.q_scale_max,
+            w.q_groups,
+            extra.q_group_map,
             none_tensor,
             none_tensor,
             none_tensor,
             temp_dq,
         )
     # GPTQ
-    elif "qweight" in w:
-        if w["scales"].dtype == torch.float:
-            w["scales"] = w["scales"].half()
+    elif isinstance(w, GPTQWeight):
+        if w.scales.dtype == torch.float:
+            w.scales = w.scales.half()
 
         # GPTQ with g_idx (act_order)
-        if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
-            w["q_perm"] = torch.empty(
-                (w["qweight"].shape[0] * 8,),
+        if w.g_idx is not None and not (w.g_idx == 0).all().item():
+            extra.q_perm = torch.empty(
+                (w.qweight.shape[0] * 8,),
                 dtype=torch.short,
-                device=w["qweight"].device,
+                device=w.qweight.device,
             )
-            w["q_invperm"] = torch.empty_like(w["q_perm"])
+            extra.q_invperm = torch.empty_like(extra.q_perm)
             # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
             return make_q_matrix(
-                w["qweight"],
-                w["q_perm"],
-                w["q_invperm"],
+                w.qweight,
+                extra.q_perm,
+                extra.q_invperm,
                 none_tensor,
                 none_tensor,
                 none_tensor,
                 none_tensor,
-                w["qzeros"],
-                w["scales"],
-                w["g_idx"].cpu(),
+                w.qzeros,
+                w.scales,
+                w.g_idx.cpu(),
                 temp_dq,
             )
         # GPTQ without g_idx
         else:
             return make_q_matrix(
-                w["qweight"],
+                w.qweight,
                 none_tensor,
                 none_tensor,
                 none_tensor,
                 none_tensor,
                 none_tensor,
                 none_tensor,
-                w["qzeros"],
-                w["scales"],
+                w.qzeros,
+                w.scales,
                 none_tensor,
                 temp_dq,
             )
@@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
 
 
 DEVICE = None
-FIXED_BYTES = 0
 LAYERS = []
 
 
@@ -134,8 +143,13 @@ def set_device(device):
 
 
 def create_exllama_buffers(max_total_tokens: int):
-    global FIXED_BYTES, LAYERS, DEVICE
-    temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
+    global LAYERS, DEVICE
+
+    # Find the size of the scratch space.
+    scratch_bytes = max(
+        layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS
+    )
+    temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
 
     for layer in LAYERS:
         layer.post_init(temp_dq)
@@ -146,49 +160,48 @@ class QuantLinear(nn.Module):
 
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
 
-    # def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
-    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
+    def __init__(
+        self,
+        weight: Exl2Weight | GPTQWeight,
+        bias: torch.Tensor,
+    ):
         super().__init__()
-        if bits != 4:
-            raise ValueError(
-                f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization."
-            )
+
         self.q_handle = None
-        self.q_tensors = None
-        self.bits = bits
-        self.maxq = 2**self.bits - 1
-        self.infeatures = qweight.shape[0] // self.bits * 32
-        self.outfeatures = qweight.shape[1]
+        self.q_tensors = weight
+        self.extra_tensors = _ExtraTensors()
+
+        if isinstance(weight, Exl2Weight):
+            self.infeatures = weight.q_invperm.shape[0]
+            self.outfeatures = weight.q_weight.shape[1]
+        elif isinstance(weight, GPTQWeight):
+            if weight.bits != 4:
+                raise ValueError(
+                    f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
+                )
+
+            self.infeatures = weight.qweight.shape[0] // weight.bits * 32
+            self.outfeatures = weight.qweight.shape[1]
+
         self.padding = -self.outfeatures % 32
         self.outfeatures = self.outfeatures + self.padding
 
-        self.device = qweight.device
-        self.qweight = qweight
-        self.qzeros = qzeros
-        self.scales = scales
-        self.g_idx = g_idx
+        self.device = weight.device
         self.bias = bias if bias is not None else None
-        self.group_size = groupsize
 
-        global FIXED_BYTES, LAYERS
-        FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
+        global LAYERS
        LAYERS.append(self)
 
     def post_init(self, temp_dq):
-        assert self.qweight.device.type == "cuda"
-        assert self.qweight.device.index is not None
-        self.q_tensors = {
-            "qweight": self.qweight,
-            "qzeros": self.qzeros,
-            "scales": self.scales,
-            "g_idx": self.g_idx,
-        }
+        device = self.q_tensors.device
+        assert device.type == "cuda"
+        assert device.index is not None
         temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
 
         # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
         # and `Memory access fault by GPU node-2` will EAT you.
         self.temp_dq = temp_dq
-        self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
+        self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)
 
     def forward(self, x, force_cuda=False):
         output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
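To restate the warmup change outside of diff form: each layer reports its fixed scratch requirement for the maximum input length, and a single shared buffer is sized to the largest of those at warmup, instead of accumulating a global FIXED_BYTES at construction time. A minimal sketch of that pattern, using hypothetical stand-ins rather than the real QuantLinear and ExLlamaV2DeviceTensors classes:

from typing import List


class FakeLayer:
    """Hypothetical stand-in for a quantized linear layer."""

    def __init__(self, out_features: int) -> None:
        self.out_features = out_features

    def scratch_space_fixed(self, max_input_len: int) -> int:
        # Toy estimate: fp16 activations for the widest intermediate tensor.
        return max_input_len * self.out_features * 2


def warmup_scratch_bytes(layers: List[FakeLayer], max_total_tokens: int) -> int:
    # Mirrors the shape of the patched create_exllama_buffers(): take the
    # maximum requirement over all registered layers once the maximum input
    # sequence length is known.
    return max(
        layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in layers
    )


if __name__ == "__main__":
    layers = [FakeLayer(4096), FakeLayer(11008)]
    print(warmup_scratch_bytes(layers, max_total_tokens=4096))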