mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Fixing all quantization kernels.
This commit is contained in:
parent
4b524a305c
commit
3ce42ba7ec
@ -1,5 +1,6 @@
|
|||||||
#include "q4_matmul.cuh"
|
#include "q4_matmul.cuh"
|
||||||
#include "column_remap.cuh"
|
#include "column_remap.cuh"
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include "../util.cuh"
|
#include "../util.cuh"
|
||||||
#include "../matrix.cuh"
|
#include "../matrix.cuh"
|
||||||
#include "../cu_compat.cuh"
|
#include "../cu_compat.cuh"
|
||||||
@ -224,8 +225,8 @@ void q4_matmul_recons_cuda
|
|||||||
const int x_height,
|
const int x_height,
|
||||||
Q4Matrix* w,
|
Q4Matrix* w,
|
||||||
half* out,
|
half* out,
|
||||||
const cublasHandle_t handle,
|
bool no_zero,
|
||||||
bool no_zero
|
const cublasHandle_t handle
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
int height = x_height;
|
int height = x_height;
|
||||||
|
@ -19,8 +19,8 @@ void q4_matmul_cuda
|
|||||||
const int x_height,
|
const int x_height,
|
||||||
const Q4Matrix* w,
|
const Q4Matrix* w,
|
||||||
half* out,
|
half* out,
|
||||||
bool no_zero = false,
|
bool no_zero,
|
||||||
cudaStream_t alt_stream = NULL
|
cudaStream_t alt_stream
|
||||||
);
|
);
|
||||||
|
|
||||||
void q4_matmul_recons_cuda
|
void q4_matmul_recons_cuda
|
||||||
@ -30,8 +30,8 @@ void q4_matmul_recons_cuda
|
|||||||
const int x_height,
|
const int x_height,
|
||||||
Q4Matrix* w,
|
Q4Matrix* w,
|
||||||
half* out,
|
half* out,
|
||||||
const cublasHandle_t handle,
|
bool no_zero,
|
||||||
bool no_zero = false
|
const cublasHandle_t handle
|
||||||
);
|
);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||||
|
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include "q4_matrix.cuh"
|
#include "q4_matrix.cuh"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "../util.cuh"
|
#include "../util.cuh"
|
||||||
@ -90,7 +91,7 @@ __global__ void make_sequential_kernel
|
|||||||
int w2_row_shift = w2_subrow << 2;
|
int w2_row_shift = w2_subrow << 2;
|
||||||
int wnew2_row_shift = i << 2;
|
int wnew2_row_shift = i << 2;
|
||||||
|
|
||||||
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
||||||
src >>= w2_row_shift;
|
src >>= w2_row_shift;
|
||||||
src &= 0x0000000f0000000f;
|
src &= 0x0000000f0000000f;
|
||||||
src <<= wnew2_row_shift;
|
src <<= wnew2_row_shift;
|
||||||
@ -146,7 +147,8 @@ void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
|
|||||||
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
|
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
|
||||||
dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);
|
dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);
|
||||||
|
|
||||||
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
make_sequential_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
|
||||||
|
|
||||||
// Replace qweights
|
// Replace qweights
|
||||||
|
|
||||||
@ -213,5 +215,6 @@ void Q4Matrix::reconstruct(half* out)
|
|||||||
1
|
1
|
||||||
);
|
);
|
||||||
|
|
||||||
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
}
|
reconstruct_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
|
||||||
|
}
|
||||||
|
@ -183,6 +183,7 @@ void q4_matmul
|
|||||||
|
|
||||||
int x_height = x.size(0);
|
int x_height = x.size(0);
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
|
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
|
||||||
{
|
{
|
||||||
q4_matmul_cuda
|
q4_matmul_cuda
|
||||||
@ -191,7 +192,9 @@ void q4_matmul
|
|||||||
(half*) x.data_ptr(),
|
(half*) x.data_ptr(),
|
||||||
x_height,
|
x_height,
|
||||||
wm,
|
wm,
|
||||||
(half*) out.data_ptr()
|
(half*) out.data_ptr(),
|
||||||
|
false,
|
||||||
|
stream
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@ -203,6 +206,7 @@ void q4_matmul
|
|||||||
x_height,
|
x_height,
|
||||||
wm,
|
wm,
|
||||||
(half*) out.data_ptr(),
|
(half*) out.data_ptr(),
|
||||||
|
false,
|
||||||
at::cuda::getCurrentCUDABlasHandle()
|
at::cuda::getCurrentCUDABlasHandle()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,7 @@ void gemm_half_q_half_cuda_part
|
|||||||
bool mul_r_weights
|
bool mul_r_weights
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
if (!b->is_gptq)
|
if (!b->is_gptq)
|
||||||
{
|
{
|
||||||
dim3 blockDim, gridDim;
|
dim3 blockDim, gridDim;
|
||||||
@ -50,7 +51,7 @@ void gemm_half_q_half_cuda_part
|
|||||||
|
|
||||||
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
|
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
|
||||||
|
|
||||||
kernel<<<gridDim, blockDim>>>
|
kernel<<<gridDim, blockDim, 0, stream>>>
|
||||||
(
|
(
|
||||||
a,
|
a,
|
||||||
b->cuda_q_weight,
|
b->cuda_q_weight,
|
||||||
@ -91,7 +92,7 @@ void gemm_half_q_half_cuda_part
|
|||||||
// print_global_mem(r_weights, 1, 1, 1);
|
// print_global_mem(r_weights, 1, 1, 1);
|
||||||
// DBGI(r_weights_stride);
|
// DBGI(r_weights_stride);
|
||||||
|
|
||||||
kernel<<<gridDim, blockDim>>>
|
kernel<<<gridDim, blockDim, 0, stream>>>
|
||||||
(
|
(
|
||||||
a,
|
a,
|
||||||
b->cuda_q_weight,
|
b->cuda_q_weight,
|
||||||
|
@ -168,8 +168,9 @@ QMatrix::QMatrix
|
|||||||
blockDim.y = 1;
|
blockDim.y = 1;
|
||||||
gridDim.x = DIVIDE(width, THREADS_X);
|
gridDim.x = DIVIDE(width, THREADS_X);
|
||||||
gridDim.y = 1;
|
gridDim.y = 1;
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
|
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
|
||||||
}
|
}
|
||||||
|
|
||||||
QMatrix::~QMatrix()
|
QMatrix::~QMatrix()
|
||||||
@ -475,11 +476,12 @@ void QMatrix::reconstruct(half* out)
|
|||||||
blockDim.x = BLOCK_KN_SIZE;
|
blockDim.x = BLOCK_KN_SIZE;
|
||||||
blockDim.y = 1;
|
blockDim.y = 1;
|
||||||
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
|
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
if (!is_gptq)
|
if (!is_gptq)
|
||||||
{
|
{
|
||||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
||||||
reconstruct_kernel<<<gridDim, blockDim>>>
|
reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||||
(
|
(
|
||||||
cuda_q_weight,
|
cuda_q_weight,
|
||||||
cuda_q_perm,
|
cuda_q_perm,
|
||||||
@ -502,7 +504,7 @@ void QMatrix::reconstruct(half* out)
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
|
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
|
||||||
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
|
reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||||
(
|
(
|
||||||
cuda_q_weight,
|
cuda_q_weight,
|
||||||
cuda_q_perm,
|
cuda_q_perm,
|
||||||
@ -563,6 +565,7 @@ __global__ void make_sequential_kernel
|
|||||||
|
|
||||||
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||||
{
|
{
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
uint32_t* cuda_new_qweight = NULL;
|
uint32_t* cuda_new_qweight = NULL;
|
||||||
cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||||
if (err != cudaSuccess) {
|
if (err != cudaSuccess) {
|
||||||
@ -621,7 +624,7 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
|||||||
gridDim.x = DIVIDE(width, THREADS_X);
|
gridDim.x = DIVIDE(width, THREADS_X);
|
||||||
gridDim.y = height / 8;
|
gridDim.y = height / 8;
|
||||||
|
|
||||||
make_sequential_kernel<<<gridDim, blockDim>>>
|
make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||||
(
|
(
|
||||||
cuda_q_weight,
|
cuda_q_weight,
|
||||||
cuda_new_qweight,
|
cuda_new_qweight,
|
||||||
|
@ -407,8 +407,9 @@ class Weights:
|
|||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
self.gptq_bits = data["quantization_config"]["bits"]
|
self.gptq_bits = data["quantization_config"]["bits"]
|
||||||
self.gptq_groupsize = data["quantization_config"]["group_size"]
|
self.gptq_groupsize = data["quantization_config"]["group_size"]
|
||||||
self.gptq_desc_act = data["quantization_config"]["desc_act"]
|
# Order is important here, desc_act is missing on some real models
|
||||||
self.quant_method = data["quantization_config"]["quant_method"]
|
self.quant_method = data["quantization_config"]["quant_method"]
|
||||||
|
self.gptq_desc_act = data["quantization_config"]["desc_act"]
|
||||||
except Exception:
|
except Exception:
|
||||||
filename = "quantize_config.json"
|
filename = "quantize_config.json"
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user