Fixing all quantization kernels.

Nicolas Patry 2024-02-09 11:12:12 +00:00
parent 4b524a305c
commit 3ce42ba7ec
7 changed files with 31 additions and 18 deletions

View File

@@ -1,5 +1,6 @@
 #include "q4_matmul.cuh"
 #include "column_remap.cuh"
+#include <ATen/cuda/CUDAContext.h>
 #include "../util.cuh"
 #include "../matrix.cuh"
 #include "../cu_compat.cuh"
@@ -224,8 +225,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero
+    bool no_zero,
+    const cublasHandle_t handle
 )
 {
     int height = x_height;
View File

@@ -19,8 +19,8 @@ void q4_matmul_cuda
     const int x_height,
     const Q4Matrix* w,
     half* out,
-    bool no_zero = false,
-    cudaStream_t alt_stream = NULL
+    bool no_zero,
+    cudaStream_t alt_stream
 );

 void q4_matmul_recons_cuda
@@ -30,8 +30,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero = false
+    bool no_zero,
+    const cublasHandle_t handle
 );

 #endif
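Both declarations lose their default arguments here, not just their parameter order, and that detail is easy to read past. A minimal sketch of why it matters (stand-in functions, not the repo's code): with a defaulted stream, a stale caller keeps compiling and silently runs on the default stream; with the default removed, it fails to compile and has to be updated, which is what forces every call site later in this commit to pass no_zero explicitly, followed by the stream or handle.

#include <cuda_runtime.h>
#include <cstdio>

// Old shape: the stream can be omitted, so a stale caller still compiles.
void matmul_old(int x_height, bool no_zero = false, cudaStream_t s = NULL)
{
    std::printf("old: stream=%p\n", (void*) s); // NULL = the default stream
}

// New shape: the caller must name a stream in the source.
void matmul_new(int x_height, bool no_zero, cudaStream_t s)
{
    std::printf("new: stream=%p\n", (void*) s);
}

int main()
{
    matmul_old(16);              // legal, but quietly uses the default stream
    // matmul_new(16);           // no longer compiles: the stream is required
    matmul_new(16, false, NULL); // the choice is now explicit and visible
    return 0;
}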

View File

@ -1,5 +1,6 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama // Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <ATen/cuda/CUDAContext.h>
#include "q4_matrix.cuh" #include "q4_matrix.cuh"
#include <vector> #include <vector>
#include "../util.cuh" #include "../util.cuh"
@@ -90,7 +91,7 @@ __global__ void make_sequential_kernel
     int w2_row_shift = w2_subrow << 2;
     int wnew2_row_shift = i << 2;

-    uint64_t src = w2[w2_row * w2_stride + w2_column];
+    uint64_t src = w2[w2_row * w2_stride + w2_column];
     src >>= w2_row_shift;
     src &= 0x0000000f0000000f;
     src <<= wnew2_row_shift;
@@ -146,7 +147,8 @@ void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
     dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
     dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);

-    make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    make_sequential_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);

     // Replace qweights
@@ -213,5 +215,6 @@ void Q4Matrix::reconstruct(half* out)
         1
     );

-    reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
 }
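Both launches above get the same two-line change that repeats through the rest of this commit: fetch the stream PyTorch is currently recording work on, and pass it as the fourth kernel-launch parameter instead of implicitly enqueueing on the default stream, where the kernel can race with tensor work PyTorch has queued elsewhere. A self-contained sketch of the pattern, with an invented kernel but the real ATen call:

#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

__global__ void scale_kernel(float* data, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] *= 2.0f;
}

void launch_on_torch_stream(float* data, int n)
{
    // at::cuda::CUDAStream converts implicitly to cudaStream_t,
    // which is why the hunks above can assign it directly.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    dim3 threads(256, 1, 1);
    dim3 blocks((n + 255) / 256, 1, 1);

    // Third launch parameter: dynamic shared memory (none here).
    // Fourth: the stream to enqueue on, which is the point of this commit.
    scale_kernel<<<blocks, threads, 0, stream>>>(data, n);
}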

View File

@@ -183,6 +183,7 @@ void q4_matmul
     int x_height = x.size(0);

+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
     {
         q4_matmul_cuda
         (
@@ -191,7 +192,9 @@ void q4_matmul
             (half*) x.data_ptr(),
             x_height,
             wm,
-            (half*) out.data_ptr()
+            (half*) out.data_ptr(),
+            false,
+            stream
         );
     }
     else
@@ -203,6 +206,7 @@ void q4_matmul
             x_height,
             wm,
             (half*) out.data_ptr(),
+            false,
             at::cuda::getCurrentCUDABlasHandle()
         );
     }
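Note that the recons branch forwards only a cublasHandle_t and no stream. That is enough because a cuBLAS handle carries its own stream, and at::cuda::getCurrentCUDABlasHandle() hands back PyTorch's handle already bound to the current stream. A sketch of the same idea when you manage a raw handle yourself (illustrative code, not from this repo):

#include <cublas_v2.h>

// Bind the handle to a stream; every subsequent cuBLAS call made
// through that handle is then enqueued on `stream`.
void gemm_on_stream(cublasHandle_t handle, cudaStream_t stream,
                    int m, int n, int k,
                    const float* a, const float* b, float* c)
{
    const float alpha = 1.0f;
    const float beta = 0.0f;
    cublasSetStream(handle, stream);
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                m, n, k,
                &alpha, a, /*lda=*/m, b, /*ldb=*/k,
                &beta, c, /*ldc=*/m);
}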

View File

@@ -38,6 +38,7 @@ void gemm_half_q_half_cuda_part
     bool mul_r_weights
 )
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (!b->is_gptq)
     {
         dim3 blockDim, gridDim;
@@ -50,7 +51,7 @@ void gemm_half_q_half_cuda_part
         fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);

-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,
@@ -91,7 +92,7 @@ void gemm_half_q_half_cuda_part
         // print_global_mem(r_weights, 1, 1, 1);
         // DBGI(r_weights_stride);

-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,

View File

@@ -168,8 +168,9 @@ QMatrix::QMatrix
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;

-    shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
 }

 QMatrix::~QMatrix()
@@ -475,11 +476,12 @@ void QMatrix::reconstruct(half* out)
     blockDim.x = BLOCK_KN_SIZE;
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);

+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (!is_gptq)
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-        reconstruct_kernel<<<gridDim, blockDim>>>
+        reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,
@@ -502,7 +504,7 @@ void QMatrix::reconstruct(half* out)
     else
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
-        reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+        reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,
@@ -563,6 +565,7 @@ __global__ void make_sequential_kernel

 bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     uint32_t* cuda_new_qweight = NULL;
     cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
     if (err != cudaSuccess) {
@@ -621,7 +624,7 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = height / 8;

-    make_sequential_kernel<<<gridDim, blockDim>>>
+    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         cuda_q_weight,
         cuda_new_qweight,

View File

@@ -407,8 +407,9 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
-            self.gptq_desc_act = data["quantization_config"]["desc_act"]
+            # Order is important here, desc_act is missing on some real models
             self.quant_method = data["quantization_config"]["quant_method"]
+            self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try: