diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu
index 09126efe..1b0f7956 100644
--- a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu
+++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu
@@ -1,5 +1,6 @@
 #include "q4_matmul.cuh"
 #include "column_remap.cuh"
+#include <ATen/cuda/CUDAContext.h>
 #include "../util.cuh"
 #include "../matrix.cuh"
 #include "../cu_compat.cuh"
@@ -224,8 +225,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero
+    bool no_zero,
+    const cublasHandle_t handle
 )
 {
     int height = x_height;
diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh
index 63611790..4c7a6669 100644
--- a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh
+++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh
@@ -19,8 +19,8 @@ void q4_matmul_cuda
     const int x_height,
     const Q4Matrix* w,
     half* out,
-    bool no_zero = false,
-    cudaStream_t alt_stream = NULL
+    bool no_zero,
+    cudaStream_t alt_stream
 );
 
 void q4_matmul_recons_cuda
@@ -30,8 +30,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero = false
+    bool no_zero,
+    const cublasHandle_t handle
 );
 
 #endif
diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu
index 2867a8d0..1f32e6b8 100644
--- a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu
+++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu
@@ -1,5 +1,6 @@
 // Adapted from turboderp exllama: https://github.com/turboderp/exllama
 
+#include <ATen/cuda/CUDAContext.h>
 #include "q4_matrix.cuh"
 #include <vector>
 #include "../util.cuh"
@@ -90,7 +91,7 @@ __global__ void make_sequential_kernel
     int w2_row_shift = w2_subrow << 2;
     int wnew2_row_shift = i << 2;
 
-    uint64_t src = w2[w2_row * w2_stride + w2_column]; 
+    uint64_t src = w2[w2_row * w2_stride + w2_column];
     src >>= w2_row_shift;
     src &= 0x0000000f0000000f;
     src <<= wnew2_row_shift;
@@ -146,7 +147,8 @@ void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
     dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
     dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);
 
-    make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    make_sequential_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
 
     // Replace qweights
 
@@ -213,5 +215,6 @@ void Q4Matrix::reconstruct(half* out)
         1
     );
 
-    reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
-}
\ No newline at end of file
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
+}
diff --git a/server/exllama_kernels/exllama_kernels/exllama_ext.cpp b/server/exllama_kernels/exllama_kernels/exllama_ext.cpp
index b786988b..f2df80e8 100644
--- a/server/exllama_kernels/exllama_kernels/exllama_ext.cpp
+++ b/server/exllama_kernels/exllama_kernels/exllama_ext.cpp
@@ -183,6 +183,7 @@ void q4_matmul
 
     int x_height = x.size(0);
 
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
     {
         q4_matmul_cuda
@@ -191,7 +192,9 @@ void q4_matmul
             (half*) x.data_ptr(),
             x_height,
             wm,
-            (half*) out.data_ptr()
+            (half*) out.data_ptr(),
+            false,
+            stream
         );
     }
     else
@@ -203,6 +206,7 @@ void q4_matmul
             x_height,
             wm,
             (half*) out.data_ptr(),
+            false,
             at::cuda::getCurrentCUDABlasHandle()
         );
     }
diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu
index b4e4cf22..5b99f1ba 100644
--- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu
+++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu
@@ -38,6 +38,7 @@ void gemm_half_q_half_cuda_part
     bool mul_r_weights
 )
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (!b->is_gptq)
     {
         dim3 blockDim, gridDim;
@@ -50,7 +51,7 @@ void gemm_half_q_half_cuda_part
 
         fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
 
-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,
@@ -91,7 +92,7 @@ void gemm_half_q_half_cuda_part
         // print_global_mem(r_weights, 1, 1, 1);
         // DBGI(r_weights_stride);
 
-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,
diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu
index 7a0038b4..f7a91e29 100644
--- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu
+++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu
@@ -168,8 +168,9 @@ QMatrix::QMatrix
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-    shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
 }
 
 QMatrix::~QMatrix()
@@ -475,11 +476,12 @@ void QMatrix::reconstruct(half* out)
     blockDim.x = BLOCK_KN_SIZE;
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     if (!is_gptq)
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-        reconstruct_kernel<<<gridDim, blockDim>>>
+        reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,
@@ -502,7 +504,7 @@ void QMatrix::reconstruct(half* out)
     else
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
-        reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+        reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,
@@ -563,6 +565,7 @@ __global__ void make_sequential_kernel
 
 bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     uint32_t* cuda_new_qweight = NULL;
     cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
     if (err != cudaSuccess) {
@@ -621,7 +624,7 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = height / 8;
 
-    make_sequential_kernel<<<gridDim, blockDim>>>
+    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         cuda_q_weight,
         cuda_new_qweight,
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 8f7e1f10..d0614346 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -407,8 +407,9 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
-            self.gptq_desc_act = data["quantization_config"]["desc_act"]
+            # Order is important here, desc_act is missing on some real models
             self.quant_method = data["quantization_config"]["quant_method"]
+            self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try: