diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh
new file mode 100644
index 00000000..19b1e4a6
--- /dev/null
+++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh
@@ -0,0 +1,38 @@
+#ifndef _compat_gemm_cuh
+#define _compat_gemm_cuh
+
+#if defined(USE_ROCM)
+
+// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required
+// for symbols as hipblasHalf.
+#include <hipblas/hipblas.h>
+
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
+                                                               hipblasOperation_t transA,
+                                                               hipblasOperation_t transB,
+                                                               int                m,
+                                                               int                n,
+                                                               int                k,
+                                                               const half*        alpha,
+                                                               const half*        AP,
+                                                               int                lda,
+                                                               const half*        BP,
+                                                               int                ldb,
+                                                               const half*        beta,
+                                                               half*              CP,
+                                                               int                ldc) {
+    return hipblasHgemm(handle, transA, transB, m, n, k,
+                        reinterpret_cast<const hipblasHalf *>(alpha),
+                        reinterpret_cast<const hipblasHalf *>(AP), lda,
+                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
+                        reinterpret_cast<const hipblasHalf *>(beta),
+                        reinterpret_cast<hipblasHalf *>(CP), ldc);
+}
+#define hipblasHgemm __compat_hipblasHgemm
+
+// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
+#define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_hgemm __compat_hipblasHgemm
+#endif
+
+#endif
diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu
index 0ffd2063..351b9cd5 100644
--- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu
+++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu
@@ -18,34 +18,7 @@
 #include "q_gemm_kernel.cuh"
 #include "q_gemm_kernel_gptq.cuh"
 
-#if defined(USE_ROCM)
-__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
-                                                               hipblasOperation_t transA,
-                                                               hipblasOperation_t transB,
-                                                               int                m,
-                                                               int                n,
-                                                               int                k,
-                                                               const half*        alpha,
-                                                               const half*        AP,
-                                                               int                lda,
-                                                               const half*        BP,
-                                                               int                ldb,
-                                                               const half*        beta,
-                                                               half*              CP,
-                                                               int                ldc) {
-    return hipblasHgemm(handle, transA, transB, m, n, k,
-                        reinterpret_cast<const hipblasHalf *>(alpha),
-                        reinterpret_cast<const hipblasHalf *>(AP), lda,
-                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
-                        reinterpret_cast<const hipblasHalf *>(beta),
-                        reinterpret_cast<hipblasHalf *>(CP), ldc);
-}
-#define hipblasHgemm __compat_hipblasHgemm
-
-// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
-#define rocblas_operation_none HIPBLAS_OP_N
-#define rocblas_hgemm __compat_hipblasHgemm
-#endif
+#include "compat_gemm.cuh"
 
 void gemm_half_q_half_cuda_part
 (
diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh
index 04643f65..0b899a84 100644
--- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh
+++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh
@@ -1,5 +1,8 @@
 #include "compat.cuh"
 
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
 __forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
 {
     half2 result = {};
diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu
index e166d8e9..6aed7470 100644
--- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu
+++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu
@@ -72,6 +72,8 @@ QMatrix::QMatrix
 {
     cudaSetDevice(device);
 
+    failed = false;
+
     cuda_q_weight = _q_weight;
     cuda_q_perm = _q_perm;
     cuda_q_invperm = _q_invperm;
@@ -125,7 +127,15 @@ QMatrix::QMatrix
         rows_3 = height;
         rows_2 = height;
 
-        if (_gptq_g_idx) make_sequential(_gptq_g_idx);
+        if (_gptq_g_idx)
+        {
+            if (!make_sequential(_gptq_g_idx))
+            {
+                failed = true;
+                //printf("FAIL\n");
+                return;
+            }
+        }
     }
 
     // Shuffle quantized data
@@ -139,6 +149,9 @@ QMatrix::QMatrix
     shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
 }
 
+QMatrix::~QMatrix()
+{
+}
 
 // Reconstruct b[k,n] (GPTQ)
 
@@ -437,11 +450,11 @@ void QMatrix::reconstruct(half* out)
     dim3 blockDim, gridDim;
     blockDim.x = BLOCK_KN_SIZE;
     blockDim.y = 1;
-    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
 
     if (!is_gptq)
     {
+        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
         reconstruct_kernel<<<gridDim, blockDim>>>
         (
             cuda_q_weight,
@@ -464,6 +477,7 @@
     }
     else
     {
+        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
         reconstruct_gptq_kernel<<<gridDim, blockDim>>>
         (
             cuda_q_weight,
@@ -523,10 +537,14 @@ __global__ void make_sequential_kernel
     w_new2[w_new2_row * w2_stride + w2_column] = dst;
 }
 
-void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
+bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
 {
     uint32_t* cuda_new_qweight = NULL;
-    cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
+    cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
+    if (err != cudaSuccess) {
+        cudaError_t cuda_status = cudaGetLastError(); // Clear error
+        return false;
+    }
 
     uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
     uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
@@ -600,4 +618,6 @@ void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
     free(cpu_g_idx_map);
     free(cpu_x_map);
     free(cpu_x_map_inv);
+
+    return true;
 }
diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh
index ed713b14..dda83a4f 100644
--- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh
+++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh
@@ -38,6 +38,8 @@ public:
 
     half* temp_dq;
 
+    bool failed;
+
     QMatrix
     (
         const int _device,
@@ -62,7 +64,7 @@ public:
     ~QMatrix();
 
     void reconstruct(half* out);
-    void make_sequential(const uint32_t* cpu_g_idx);
+    bool make_sequential(const uint32_t* cpu_g_idx);
 
 private:
 
diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh
index 21837961..06a58d18 100644
--- a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh
+++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh
@@ -30,3 +30,13 @@ __forceinline__ __device__ float clamp(float x, float a, float b)
 {
     return fmaxf(a, fminf(b, x));
 }
+
+#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
+inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
+{
+    if (code != cudaSuccess)
+    {
+        fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
+        if (abort) exit(code);
+    }
+}
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 75d2b159..fa831682 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -168,7 +168,7 @@ def serve(
             # When using GPTQ, Exllama kernels need some global kernels
             # For which we have the finale shapes only after the model has loaded
             # This will allocate those buffers.
-            from text_generation_server.utils.gptq.exllama import (
+            from text_generation_server.utils.layers import (
                 create_exllama_buffers,
                 set_device,
             )
diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py
index dc8353f3..f546f3af 100644
--- a/server/text_generation_server/utils/gptq/exllamav2.py
+++ b/server/text_generation_server/utils/gptq/exllamav2.py
@@ -17,10 +17,6 @@ except ImportError:
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
 none_tensor = torch.empty((1, 1), device="meta")
 
-def _torch_device(idx):
-    if idx == -1: return "cpu"
-    return f"cuda:{idx}"
-
 def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
     """Matrix multiplication, returns x @ q4"""
     output_shape = x.shape[:-1] + (q4_width,)
@@ -82,6 +78,53 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             none_tensor,
             temp_dq)
 
+DEVICE = None
+FIXED_BYTES = 0
+LAYERS = []
+
+
+def set_device(device):
+    global DEVICE
+    DEVICE = device
+
+
+def create_exllama_buffers():
+    global FIXED_BYTES, LAYERS, DEVICE
+    temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
+
+    for layer in LAYERS:
+        layer.post_init(temp_dq)
+
+
+    # assert DEVICE is not None, "call set_device first"
+
+    # if ACT_ORDER:
+    #     # TODO: this should be set to rust side `max_total_tokens`, but TGI
+    #     # does not offer an API to expose this variable to python, as this variable
+    #     # is handled by the client but it appears the model is initialized by the server.
+    #     # An alternative could be to initialize the buffers during warmup.
+    #     # Dummy
+    #     max_total_tokens = 2048
+    # else:
+    #     max_total_tokens = 1
+
+    # # This temp_state buffer is required to reorder X in the act-order case.
+    # temp_state = torch.zeros(
+    #     (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
+    # )
+    # temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
+
+    # # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
+    # prepare_buffers(DEVICE, temp_state, temp_dq)
+
+    # matmul_recons_thd = 8
+    # matmul_fused_remap = False
+    # matmul_no_half2 = False
+    # set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+    # TEMP_STATE, TEMP_DQ = temp_state, temp_dq
+
+
 class QuantLinear(nn.Module):
     QUANT_TYPE = "exllamav2"
@@ -98,7 +141,6 @@ class QuantLinear(nn.Module):
 
         self.q_handle = None
         self.q_tensors = None
-        # self.padding = - outfeatures % 32
         #
         # self.infeatures = infeatures
         # self.outfeatures = outfeatures + self.padding
@@ -112,7 +154,8 @@
         # assert infeatures % 32 == 0
         # assert infeatures % self.group_size == 0
         # assert outfeatures % 32 == 0
-        #
+        # self.padding = - outfeatures % 32
+        #
         # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
         # self.register_buffer(
         #     'qweight',
@@ -137,13 +180,16 @@
         self.g_idx = g_idx
         self.bias = bias if bias is not None else None
 
+        global FIXED_BYTES, LAYERS
+        FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
+        LAYERS.append(self)
+
         # if bias:
         #     self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
         # else:
         #     self.bias = None
 
-    # def post_init(self, temp_dq):
-        temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index , self.temp_dq_size() + self.temp_fwd_size(4096, 8))
+    def post_init(self, temp_dq):
         assert self.qweight.device.type == "cuda"
         assert self.qweight.device.index is not None
         self.q_tensors = {
@@ -152,7 +198,7 @@ class QuantLinear(nn.Module):
             "scales":self.scales,
             "g_idx":self.g_idx
         }
-        temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size() + self.temp_fwd_size(4096, 8))
+        temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
         self.q_handle = ext_make_q_matrix(
             self.q_tensors, temp_dq
         )
@@ -181,12 +227,12 @@ class ExLlamaV2DeviceTensors:
     scratch_idx: int
     scratch: torch.tensor = None
 
-    def __init__(self, device_idx, scratch_bytes):
-        self.device_idx = device_idx
+    def __init__(self, device, scratch_bytes):
+        self.device = device
         self.scratch_bytes = scratch_bytes
 
     def prepare(self):
-        self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = _torch_device(self.device_idx))
+        self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device)
 
     def get_scratch_slice(self, size_bytes):
 
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 78f2de8e..13bd422a 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -31,16 +31,27 @@ try:
     major, _minor = torch.cuda.get_device_capability()
 except Exception:
     major = 1
+
 HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8
+V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
 elif CAN_EXLLAMA:
     try:
-        from text_generation_server.utils.gptq.exllama import Ex4bitLinear
-        from text_generation_server.utils.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear
+        if V2:
+            from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear,
+                                                                     create_exllama_buffers,
+                                                                     set_device,
+                                                                     )
+            HAS_EXLLAMA = "2"
+        else:
+            from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
+                                                                   create_exllama_buffers,
+                                                                   set_device,
+                                                                   )
+            HAS_EXLLAMA = "1"
 
-        HAS_EXLLAMA = True
     except ImportError:
         pass
 
@@ -309,7 +320,7 @@ def get_linear(weight, bias, quantize):
             )
 
         if use_exllama:
-            linear = exllamav2QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
+            linear = ExllamaQuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
         else:
             linear = QuantLinear(
                 qweight,
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 2f330d9c..f03892ba 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -278,7 +278,7 @@ class Weights:
                 )
                 use_exllama = False
             else:
-                logger.info("Using exllama kernels")
+                logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
 
         if use_exllama:
             if groupsize >= 0:
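Note on the buffer-management flow introduced in exllamav2.py above: set_device() records the target device in the module-level DEVICE, every exllamav2 QuantLinear registers itself in LAYERS and raises FIXED_BYTES by its scratch_space_fixed() requirement at construction time, and create_exllama_buffers() then allocates a single ExLlamaV2DeviceTensors scratch space and hands each registered layer a slice of it through post_init(). A minimal usage sketch follows (it assumes the imports that server.py performs in this patch; the load_quantized_model helper is hypothetical and only stands in for TGI's model initialization):

    import torch
    from text_generation_server.utils.layers import create_exllama_buffers, set_device

    # 1. Record the device the shared exllamav2 scratch buffer should live on.
    set_device(torch.device("cuda:0"))

    # 2. Build the model. Every exllamav2 QuantLinear constructed here appends itself
    #    to LAYERS and grows FIXED_BYTES via scratch_space_fixed().
    model = load_quantized_model("some-gptq-model")  # hypothetical stand-in

    # 3. Allocate one ExLlamaV2DeviceTensors scratch space of FIXED_BYTES on DEVICE
    #    and hand each registered layer its slice through post_init().
    create_exllama_buffers()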