From a6e387404d92e219c648599b7619ebc159c3998a Mon Sep 17 00:00:00 2001
From: Felix Marty <9808326+fxmarty@users.noreply.github.com>
Date: Wed, 5 Jul 2023 17:53:56 +0000
Subject: [PATCH] try-catch to load the cuda extension, quite ugly practice
 tbh

---
 .../utils/gptq/quant_linear.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py
index b01788bb..f7c06e0f 100644
--- a/server/text_generation_server/utils/gptq/quant_linear.py
+++ b/server/text_generation_server/utils/gptq/quant_linear.py
@@ -4,6 +4,15 @@ import torch
 import torch.nn as nn
 from torch.cuda.amp import custom_bwd, custom_fwd
 
+import torch
+
+from loguru import logger
+
+try:
+    from custom_kernels.exllama import make_q4, q4_matmul
+except Exception as e:
+    logger.error(f"The CUDA kernels custom_kernels.exllama could not be imported, got the error: {e}")
+
 try:
     import triton
     import triton.language as tl
@@ -359,9 +368,6 @@ class QuantLinear(nn.Module):
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
 
-import torch
-from custom_kernels.exllama import make_q4, q4_matmul
-
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
 none_tensor = torch.empty((1, 1), device = "meta")
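
Note on the fallback: as written, a failed import is only logged, so any code
path that later calls make_q4 or q4_matmul will still crash with a NameError.
A minimal sketch of a more explicit guard, assuming a module-level
availability flag (HAS_EXLLAMA_KERNELS and _require_exllama_kernels are
hypothetical names, not part of this patch):

    from loguru import logger

    try:
        from custom_kernels.exllama import make_q4, q4_matmul

        HAS_EXLLAMA_KERNELS = True
    except Exception as e:
        # Keep the module importable on machines without the compiled
        # extension, but record that the exllama kernels are missing.
        logger.error(f"The CUDA kernels custom_kernels.exllama could not be imported: {e}")
        HAS_EXLLAMA_KERNELS = False

    def _require_exllama_kernels():
        # Hypothetical helper: call sites check availability up front and
        # fail with an actionable message instead of a NameError later on.
        if not HAS_EXLLAMA_KERNELS:
            raise ImportError(
                "custom_kernels.exllama is not installed; the exllama GPTQ "
                "kernels (make_q4 / q4_matmul) are unavailable."
            )

A call site would run _require_exllama_kernels() before touching make_q4 or
q4_matmul, which keeps the try/except at import time while surfacing the
error where the kernels are actually needed.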