From cdd120ac022f17819be2c440d5efa84d8dafe7d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 5 Jun 2024 10:45:47 +0200 Subject: [PATCH] Do not initialize scratch space when there are no ExLlamaV2 layers (#2015) # What does this PR do? Do not attempt to allocate ExLlamaV2 scratch buffers when there are no ExLlamaV2 layers. Avoids a crash in warmup for models that cannot use exllama when ExLlamaV2 is installed. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. 
--- .../layers/gptq/exllamav2.py | 159 ++++++++++-------- 1 file changed, 89 insertions(+), 70 deletions(-) diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py index 321ced97..4d45822b 100644 --- a/server/text_generation_server/layers/gptq/exllamav2.py +++ b/server/text_generation_server/layers/gptq/exllamav2.py @@ -1,10 +1,15 @@ # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 +from dataclasses import dataclass +from typing import Optional import torch import torch.nn as nn from loguru import logger +from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.layers.gptq import GPTQWeight + try: from exllamav2_kernels import make_q_matrix, gemm_half_q_half except ImportError: @@ -15,6 +20,15 @@ except ImportError: none_tensor = torch.empty((1, 1), device="meta") +@dataclass +class _ExtraTensors: + """Additional generated quantizer tensors.""" + + q_group_map: Optional[torch.Tensor] = None + q_invperm: Optional[torch.Tensor] = None + q_perm: Optional[torch.Tensor] = None + + def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): """Matrix multiplication, returns x @ q4""" output_shape = x.shape[:-1] + (q4_width,) @@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): return output.view(output_shape) -# Group map needed for irregular group sizes - - -def make_group_map(q_groups, num_qrows): - +def make_group_map(q_groups: torch.Tensor, num_qrows: int): gr = q_groups.tolist() group_map = [] num_groups = len(gr) // 2 @@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows): # Create Q matrix -def ext_make_q_matrix(w: dict, temp_dq, key: str = None): +def ext_make_q_matrix( + w: Exl2Weight | GPTQWeight, + extra: _ExtraTensors, + temp_dq, + key: Optional[str] = None, +): """ Create Q matrix """ # EXL2 - # won't work as the moment because the tensors are not the same. 
- if "q_weight" in w: - w["q_scale_max"] /= 256 - w["q_perm"] = w["q_perm"].short() - w["q_invperm"] = w["q_invperm"].short() - - if "q_group_map" not in w: - w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0]) + if isinstance(w, Exl2Weight): + extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0]) + extra.q_perm = torch.argsort(w.q_invperm).short() return make_q_matrix( - w["q_weight"], - w["q_perm"], - w["q_invperm"], - w["q_scale"], - w["q_scale_max"], - w["q_groups"], - w["q_group_map"], + w.q_weight, + extra.q_perm, + w.q_invperm, + w.q_scale, + w.q_scale_max, + w.q_groups, + extra.q_group_map, none_tensor, none_tensor, none_tensor, temp_dq, ) # GPTQ - elif "qweight" in w: - if w["scales"].dtype == torch.float: - w["scales"] = w["scales"].half() + elif isinstance(w, GPTQWeight): + if w.scales.dtype == torch.float: + w.scales = w.scales.half() # GPTQ with g_idx (act_order) - if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): - w["q_perm"] = torch.empty( - (w["qweight"].shape[0] * 8,), + if w.g_idx is not None and not (w.g_idx == 0).all().item(): + extra.q_perm = torch.empty( + (w.qweight.shape[0] * 8,), dtype=torch.short, - device=w["qweight"].device, + device=w.qweight.device, ) - w["q_invperm"] = torch.empty_like(w["q_perm"]) + extra.q_invperm = torch.empty_like(extra.q_perm) # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. 
return make_q_matrix( - w["qweight"], - w["q_perm"], - w["q_invperm"], + w.qweight, + extra.q_perm, + extra.q_invperm, none_tensor, none_tensor, none_tensor, none_tensor, - w["qzeros"], - w["scales"], - w["g_idx"].cpu(), + w.qzeros, + w.scales, + w.g_idx.cpu(), temp_dq, ) # GPTQ without g_idx else: return make_q_matrix( - w["qweight"], + w.qweight, none_tensor, none_tensor, none_tensor, none_tensor, none_tensor, none_tensor, - w["qzeros"], - w["scales"], + w.qzeros, + w.scales, none_tensor, temp_dq, ) @@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): DEVICE = None -FIXED_BYTES = 0 LAYERS = [] @@ -134,8 +143,19 @@ def set_device(device): def create_exllama_buffers(max_total_tokens: int): - global FIXED_BYTES, LAYERS, DEVICE - temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES) + global LAYERS, DEVICE + + # No need to initialize scratch space if there are no layers + # that use ExLLamav2. + if len(LAYERS) == 0: + return + + # Find the size of the scratch space. + scratch_bytes = max( + layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1) + for layer in LAYERS + ) + temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes) for layer in LAYERS: layer.post_init(temp_dq) @@ -146,49 +166,48 @@ class QuantLinear(nn.Module): """Linear layer implementation with per-group 4-bit quantization of the weights""" - # def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs): - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + def __init__( + self, + weight: Exl2Weight | GPTQWeight, + bias: torch.Tensor, + ): super().__init__() - if bits != 4: - raise ValueError( - f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization." 
- ) + self.q_handle = None - self.q_tensors = None - self.bits = bits - self.maxq = 2**self.bits - 1 - self.infeatures = qweight.shape[0] // self.bits * 32 - self.outfeatures = qweight.shape[1] + self.q_tensors = weight + self.extra_tensors = _ExtraTensors() + + if isinstance(weight, Exl2Weight): + self.infeatures = weight.q_invperm.shape[0] + self.outfeatures = weight.q_weight.shape[1] + elif isinstance(weight, GPTQWeight): + if weight.bits != 4: + raise ValueError( + f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization." + ) + + self.infeatures = weight.qweight.shape[0] // weight.bits * 32 + self.outfeatures = weight.qweight.shape[1] + self.padding = -self.outfeatures % 32 self.outfeatures = self.outfeatures + self.padding - self.device = qweight.device - self.qweight = qweight - self.qzeros = qzeros - self.scales = scales - self.g_idx = g_idx + self.device = weight.device self.bias = bias if bias is not None else None - self.group_size = groupsize - global FIXED_BYTES, LAYERS - FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed()) + global LAYERS LAYERS.append(self) def post_init(self, temp_dq): - assert self.qweight.device.type == "cuda" - assert self.qweight.device.index is not None - self.q_tensors = { - "qweight": self.qweight, - "qzeros": self.qzeros, - "scales": self.scales, - "g_idx": self.g_idx, - } + device = self.q_tensors.device + assert device.type == "cuda" + assert device.index is not None temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us, # and `Memory access fault by GPU node-2` will EAT you. 
self.temp_dq = temp_dq - self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) + self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq) def forward(self, x, force_cuda=False): output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) @@ -203,7 +222,7 @@ class QuantLinear(nn.Module): def temp_fwd_size(self, max_input_len, max_batch_size): return self.outfeatures * max_input_len * max_batch_size * 4 + 128 - def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16): + def scratch_space_fixed(self, max_input_len, max_batch_size): return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)