Fixing all quantization kernels.

This commit is contained in:
Nicolas Patry 2024-02-09 11:12:12 +00:00
parent 4b524a305c
commit 3ce42ba7ec
7 changed files with 31 additions and 18 deletions

View File

@ -1,5 +1,6 @@
#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include <ATen/cuda/CUDAContext.h>
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cu_compat.cuh"
@ -224,8 +225,8 @@ void q4_matmul_recons_cuda
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero
bool no_zero,
const cublasHandle_t handle
)
{
int height = x_height;

View File

@ -19,8 +19,8 @@ void q4_matmul_cuda
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero = false,
cudaStream_t alt_stream = NULL
bool no_zero,
cudaStream_t alt_stream
);
void q4_matmul_recons_cuda
@ -30,8 +30,8 @@ void q4_matmul_recons_cuda
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero = false
bool no_zero,
const cublasHandle_t handle
);
#endif

View File

@ -1,5 +1,6 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <ATen/cuda/CUDAContext.h>
#include "q4_matrix.cuh"
#include <vector>
#include "../util.cuh"
@ -90,7 +91,7 @@ __global__ void make_sequential_kernel
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
@ -146,7 +147,8 @@ void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
make_sequential_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
// Replace qweights
@ -213,5 +215,6 @@ void Q4Matrix::reconstruct(half* out)
1
);
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
}

View File

@ -183,6 +183,7 @@ void q4_matmul
int x_height = x.size(0);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
@ -191,7 +192,9 @@ void q4_matmul
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr()
(half*) out.data_ptr(),
false,
stream
);
}
else
@ -203,6 +206,7 @@ void q4_matmul
x_height,
wm,
(half*) out.data_ptr(),
false,
at::cuda::getCurrentCUDABlasHandle()
);
}

View File

@ -38,6 +38,7 @@ void gemm_half_q_half_cuda_part
bool mul_r_weights
)
{
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (!b->is_gptq)
{
dim3 blockDim, gridDim;
@ -50,7 +51,7 @@ void gemm_half_q_half_cuda_part
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
kernel<<<gridDim, blockDim>>>
kernel<<<gridDim, blockDim, 0, stream>>>
(
a,
b->cuda_q_weight,
@ -91,7 +92,7 @@ void gemm_half_q_half_cuda_part
// print_global_mem(r_weights, 1, 1, 1);
// DBGI(r_weights_stride);
kernel<<<gridDim, blockDim>>>
kernel<<<gridDim, blockDim, 0, stream>>>
(
a,
b->cuda_q_weight,

View File

@ -168,8 +168,9 @@ QMatrix::QMatrix
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}
QMatrix::~QMatrix()
@ -475,11 +476,12 @@ void QMatrix::reconstruct(half* out)
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (!is_gptq)
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
reconstruct_kernel<<<gridDim, blockDim>>>
reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
(
cuda_q_weight,
cuda_q_perm,
@ -502,7 +504,7 @@ void QMatrix::reconstruct(half* out)
else
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
(
cuda_q_weight,
cuda_q_perm,
@ -563,6 +565,7 @@ __global__ void make_sequential_kernel
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
{
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
uint32_t* cuda_new_qweight = NULL;
cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
if (err != cudaSuccess) {
@ -621,7 +624,7 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = height / 8;
make_sequential_kernel<<<gridDim, blockDim>>>
make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
(
cuda_q_weight,
cuda_new_qweight,

View File

@ -407,8 +407,9 @@ class Weights:
data = json.load(f)
self.gptq_bits = data["quantization_config"]["bits"]
self.gptq_groupsize = data["quantization_config"]["group_size"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
# Order is important here, desc_act is missing on some real models
self.quant_method = data["quantization_config"]["quant_method"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
except Exception:
filename = "quantize_config.json"
try: