try-catch to load the cuda extension, quite ugly practice tbh

This commit is contained in:
Felix Marty 2023-07-05 17:53:56 +00:00
parent 620ed7d8aa
commit a6e387404d

View File

@@ -4,6 +4,15 @@ import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd
import torch
from loguru import logger
# Optional CUDA extension: the exllama kernels may be absent (e.g. a CPU-only
# install), so a failed import is logged instead of raised and the module
# keeps loading.
try:
    from custom_kernels.exllama import make_q4, q4_matmul
except Exception as e:
    # Broad catch is deliberate: loading a compiled extension can fail with
    # more than ImportError (e.g. a shared-library/CUDA initialization error),
    # so the message reports "could not be loaded" rather than claiming the
    # package is merely not installed.
    logger.error(f"The CUDA kernels custom_kernels.exllama could not be loaded: {e}")
try:
import triton
import triton.language as tl
@@ -359,9 +368,6 @@ class QuantLinear(nn.Module):
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
import torch
from custom_kernels.exllama import make_q4, q4_matmul
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
# NOTE(review): allocated on the "meta" device, so it carries no real storage —
# presumably the C++ side only checks identity/shape, never reads data; confirm.
none_tensor = torch.empty((1, 1), device = "meta")