#include "q_matrix.cuh"
|
|
#include "matrix_view.cuh"
|
|
#include "util.cuh"
|
|
|
|
#include "quant/qdq_2.cuh"
|
|
#include "quant/qdq_3.cuh"
|
|
#include "quant/qdq_4.cuh"
|
|
#include "quant/qdq_5.cuh"
|
|
#include "quant/qdq_6.cuh"
|
|
#include "quant/qdq_8.cuh"
|
|
|
|
#define BLOCK_KN_SIZE 128
|
|
|
|
#define THREADS_X 32
|
|
#define THREADS_Y 32
|
|
|
|
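
// Tile constants: BLOCK_KN_SIZE is the k/n tile extent covered by one thread
// block in the reconstruct kernels below; THREADS_X is the number of columns
// handled per block in shuffle_kernel and make_sequential_kernel. THREADS_Y
// is defined here but not used in this file.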
// Shuffle quantized data on load
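//
// rows_8 .. rows_2 are cumulative row boundaries (set up in the QMatrix
// constructor): rows_8 is the first row index past the 8-bit segment, rows_6
// the first index past the 6-bit segment, and so on. Each thread owns one
// column of the packed weight matrix and walks down it, calling the matching
// shuffle_*bit_* helper so the packed words are reordered in place,
// presumably into the layout expected by the dequant_* routines in
// quant/qdq_*.cuh.
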
__global__ void shuffle_kernel
(
    uint32_t* __restrict__ b_q_weight,
    const int size_k,
    const int size_n,
    const int rows_8,
    const int rows_6,
    const int rows_5,
    const int rows_4,
    const int rows_3,
    const int rows_2
)
{
    int n = blockIdx.x * THREADS_X + threadIdx.x;
    if (n >= size_n) return;
    int k = 0;
    uint32_t* b_ptr = b_q_weight + n;
    while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
    while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
    while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
    while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
    while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
    while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
}

// QMatrix constructor

QMatrix::QMatrix
(
    const int _device,
    const int _height,
    const int _width,
    const int _groups,

    uint32_t* _q_weight,
    uint16_t* _q_perm,
    uint16_t* _q_invperm,
    uint32_t* _q_scale,
    half* _q_scale_max,
    uint16_t* _q_groups,
    uint16_t* _q_group_map,

    uint32_t* _gptq_qzeros,
    half* _gptq_scales,
    uint32_t* _gptq_g_idx,

    half* _temp_dq
) :
    device(_device),
    height(_height),
    width(_width),
    groups(_groups),
    temp_dq(_temp_dq)
{
    cudaSetDevice(device);

    failed = false;

    cuda_q_weight = _q_weight;
    cuda_q_perm = _q_perm;
    cuda_q_invperm = _q_invperm;
    cuda_q_scale = _q_scale;
    cuda_q_scale_max = _q_scale_max;
    cuda_q_groups = _q_groups;
    cuda_q_group_map = _q_group_map;
    cuda_gptq_qzeros = _gptq_qzeros;
    cuda_gptq_scales = _gptq_scales;

    is_gptq = (_gptq_qzeros != NULL);

    if (is_gptq)
    {
        gptq_groupsize = 1;
        while (gptq_groupsize * groups < height) gptq_groupsize *= 2;
    }
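    // The GPTQ tensors do not carry their group size explicitly, so it is
    // inferred here as the smallest power of two such that
    // gptq_groupsize * groups >= height (the k dimension).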

    // Create group map
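    //
    // For EXL2 (non-GPTQ) tensors, cuda_q_groups holds one (bits, q_row) pair
    // of uint16 values per group: the bit width of the group and the group's
    // row offset within the packed (quantized) weight matrix. The loop below
    // converts those packed-row extents into per-bit-width row counts of the
    // unpacked matrix (32 / bits weights per packed uint32 row).
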
    rows_8 = 0;
    rows_6 = 0;
    rows_5 = 0;
    rows_4 = 0;
    rows_3 = 0;
    rows_2 = 0;

    if (!is_gptq)
    {
        uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
        cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);

        int row = 0;
        for (int i = 0; i < groups; i++)
        {
            int bits = cpu_q_groups[i * 2];

            int rows;
            if (i < groups - 1)
            {
                int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
                rows = qrows * 32 / bits;
            }
            else rows = height - row;

            if (bits == 8) rows_8 += rows;
            if (bits == 6) rows_6 += rows;
            if (bits == 5) rows_5 += rows;
            if (bits == 4) rows_4 += rows;
            if (bits == 3) rows_3 += rows;
            if (bits == 2) rows_2 += rows;
            row += rows;
        }

        free(cpu_q_groups);

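        // Turn the per-bit-width row counts into cumulative boundaries, in
        // order of decreasing bit width: after this, rows_8 is the first row
        // past the 8-bit segment, rows_6 the first row past the 6-bit
        // segment, and so on down to rows_2 == height.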
        rows_6 += rows_8;
        rows_5 += rows_6;
        rows_4 += rows_5;
        rows_3 += rows_4;
        rows_2 += rows_3;
    }
    else
    {
        rows_4 = height;
        rows_3 = height;
        rows_2 = height;

        if (_gptq_g_idx)
        {
            if (!make_sequential(_gptq_g_idx))
            {
                failed = true;
                //printf("FAIL\n");
                return;
            }
        }
    }

    // DBGI(rows_8);
    // DBGI(rows_6);
    // DBGI(rows_5);
    // DBGI(rows_4);
    // DBGI(rows_3);
    // DBGI(rows_2);

    // Shuffle quantized data

    dim3 blockDim, gridDim;
    blockDim.x = THREADS_X;
    blockDim.y = 1;
    gridDim.x = DIVIDE(width, THREADS_X);
    gridDim.y = 1;
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

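    // One thread per output column; each thread shuffles its entire column of
    // packed words in place, so only a 1-D grid over the n dimension is
    // needed. For GPTQ tensors the rows_* values set above make this a pure
    // 4-bit shuffle over all rows.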
    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}

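// QMatrix does not own any of the buffers handed to its constructor, and the
// only allocation it makes itself (cuda_new_qweight in make_sequential) is
// freed there, so the destructor has nothing to release.
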
QMatrix::~QMatrix()
{
}

// Reconstruct b[k,n] (GPTQ)
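//
// Dequantizes the full GPTQ weight matrix into fp16. Each block covers a
// BLOCK_KN_SIZE x (BLOCK_KN_SIZE * 4) tile of b, and each thread handles four
// adjacent columns. Note the "(zeros[i] + 1) & 0x0F" below: GPTQ checkpoints
// store the zero point minus one, so it is re-added before dequantization.
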
__global__ void reconstruct_gptq_kernel
(
    const uint32_t* __restrict__ b_q_weight,
    const uint16_t* __restrict__ b_q_perm,
    const uint32_t* __restrict__ b_gptq_qzeros,
    const half* __restrict__ b_gptq_scales,
    //const uint16_t* __restrict__ b_q_groups,
    const int size_k,
    const int size_n,
    const int groupsize,
    const int groups,
    half* __restrict__ b,
    const int rows_4
)
{
    MatrixView_half_rw b_(b, size_k, size_n);
    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

    int offset_k = BLOCK_KN_SIZE * blockIdx.y;
    int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

    // Preload remapping table

    __shared__ uint16_t perm[BLOCK_KN_SIZE];
    int t = threadIdx.x;

    if (b_q_perm)
    {
        if (offset_k + t < size_k)
            perm[t] = b_q_perm[offset_k + t];
    }

    // Column

    int n = offset_n + t * 4;
    if (n >= size_n) return;

    // Find initial group

    int group = offset_k / groupsize;
    int nextgroup = offset_k + groupsize;

    // b offset

    int qk = offset_k / (32 / 4);

    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;

    // Initial zeros/scale

    int zeros[4];
    half2 scales[4];
    half2 z1z16[4][2];
    half2 y1y16[4][2];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4_h2(scales, group, n);
    dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
    dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
    dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
    dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);

    __syncthreads();

    int k = offset_k;
    int lk = 0;

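    // Main loop: each iteration dequantizes 32 rows for this thread's four
    // columns (each of the four int4 loads supplies one packed uint32 per
    // column, i.e. 8 rows). Zero points and scales are refreshed whenever k
    // crosses a group boundary, and rows are written back through perm[] when
    // a permutation is present, otherwise at their natural offsets.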
    while (k < end_k)
    {
        if (k == nextgroup)
        {
            group++;
            nextgroup += groupsize;
            b_gptq_qzeros_.item4(zeros, group, n);
            b_gptq_scales_.item4_h2(scales, group, n);
            dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
            dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
            dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
            dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
        }

        for (int p = 0; p < 4; p++)
        {
            half2 dq[4][4];
            const int4* b_ptr4 = (int4*) b_ptr;
            int4 load_int4 = *b_ptr4;

            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);

            b_ptr += size_n;
            //half* dqh = (half*)dq;
            if (b_q_perm)
            {
                for (int j = 0; j < 4; j++)
                {
                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
                    b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
                    b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
                }
            }
            else
            {
                for (int j = 0; j < 4; j++)
                {
                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
                    b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
                    b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
                }
            }
        }
        k += 32;
    }
}

// Reconstruct b[k,n]
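//
// EXL2 variant of the reconstruction kernel. Each thread owns a single column
// and walks down a BLOCK_KN_SIZE slice of rows, which may span several
// bit-width segments (8, 6, 5, 4, 3, 2 bits). b_q_group_map stores, for every
// row, the group index and the distance to the next group boundary.
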
__global__ void reconstruct_kernel
(
    const uint32_t* __restrict__ b_q_weight,
    const uint16_t* __restrict__ b_q_perm,
    const uint32_t* __restrict__ b_q_scale,
    const half* __restrict__ b_q_scale_max,
    const uint16_t* __restrict__ b_q_group_map,
    const int size_k,
    const int size_n,
    //const int groupsize,
    const int groups,
    half* __restrict__ b,
    const int rows_8,
    const int rows_6,
    const int rows_5,
    const int rows_4,
    const int rows_3,
    const int rows_2
)
{
    MatrixView_half_rw b_(b, size_k, size_n);
    MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);

    int offset_k = BLOCK_KN_SIZE * blockIdx.y;
    int offset_n = BLOCK_KN_SIZE * blockIdx.x;

    // Preload remapping table

    int t = threadIdx.x;
    __shared__ uint16_t perm[BLOCK_KN_SIZE];
    if (offset_k + t < size_k)
        perm[t] = b_q_perm[offset_k + t];

    // Column

    int n = offset_n + t;
    if (n >= size_n) return;

    // Find initial group

    // int group = offset_k / groupsize;
    int group = b_q_group_map[offset_k * 2];

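    // Compute qk, the row offset of offset_k within the packed (uint32)
    // weight matrix. Each bit-width segment packs 32 rows of b-bit weights
    // into b uint32 rows, so the packed offset is accumulated segment by
    // segment up to offset_k.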
    int pre_rows_8 = min(rows_8, offset_k);
    int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
    int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
    int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
    int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
    int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
    int qk = 0;
    qk += pre_rows_8 / 32 * 8;
    qk += pre_rows_6 / 32 * 6;
    qk += pre_rows_5 / 32 * 5;
    qk += pre_rows_4 / 32 * 4;
    qk += pre_rows_3 / 32 * 3;
    qk += pre_rows_2 / 32 * 2;

    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;

    half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
    half2 qs_h2 = __halves2half2(qs_h, qs_h);
    int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];

    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
    int k = offset_k;
    int lk = 0;

    __syncthreads();

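    // One loop per bit-width segment, from 8-bit down to 2-bit. Each
    // iteration dequantizes 32 rows (16 for the final 2-bit loop), multiplies
    // by the per-group scale, and scatters the results into b through the
    // permutation table. The group scale is re-derived from the packed 4-bit
    // scale and the per-group fp16 maximum whenever a group boundary is
    // crossed.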
    while (k < rows_8 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 4; p++)
        {
            half2 dq[4];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            uint32_t q_1 = *b_ptr; b_ptr += size_n;
            dequant_8bit_8(q_0, q_1, dq, size_n);
            for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 32;
    }

    while (k < rows_6 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 2; p++)
        {
            half2 dq[8];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            uint32_t q_1 = *b_ptr; b_ptr += size_n;
            uint32_t q_2 = *b_ptr; b_ptr += size_n;
            dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
            for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 32;
    }

    while (k < rows_5 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 1; p++)
        {
            half2 dq[16];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            uint32_t q_1 = *b_ptr; b_ptr += size_n;
            uint32_t q_2 = *b_ptr; b_ptr += size_n;
            uint32_t q_3 = *b_ptr; b_ptr += size_n;
            uint32_t q_4 = *b_ptr; b_ptr += size_n;
            dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
            for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 32;
    }

    while (k < rows_4 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 4; p++)
        {
            half2 dq[4];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            dequant_4bit_8(q_0, dq, size_n);
            for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 32;
    }

    while (k < rows_3 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 1; p++)
        {
            half2 dq[16];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            uint32_t q_1 = *b_ptr; b_ptr += size_n;
            uint32_t q_2 = *b_ptr; b_ptr += size_n;
            dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
            for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 32;
    }

    while (k < rows_2 && k < end_k)
    {
        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
        for (int p = 0; p < 1; p++)
        {
            half2 dq[8];
            uint32_t q_0 = *b_ptr; b_ptr += size_n;
            dequant_2bit_16(q_0, dq, size_n);
            for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
            half* dqh = (half*) dq;
            for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
        }
        k += 16;
    }
}

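// Host-side entry point for reconstruction. The grid covers BLOCK_KN_SIZE
// rows per block in y. In x, the EXL2 kernel covers BLOCK_KN_SIZE columns per
// block while the GPTQ kernel covers BLOCK_KN_SIZE * 4 (four columns per
// thread), hence the different gridDim.x calculations below.
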
void QMatrix::reconstruct(half* out)
{
    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (!is_gptq)
    {
        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
        reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
        (
            cuda_q_weight,
            cuda_q_perm,
            cuda_q_scale,
            cuda_q_scale_max,
            cuda_q_group_map,
            height,
            width,
            //groupsize,
            groups,
            out,
            rows_8,
            rows_6,
            rows_5,
            rows_4,
            rows_3,
            rows_2
        );
    }
    else
    {
        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
        reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
        (
            cuda_q_weight,
            cuda_q_perm,
            cuda_gptq_qzeros,
            cuda_gptq_scales,
            //const uint16_t* __restrict__ b_q_groups,
            height,
            width,
            gptq_groupsize,
            groups,
            out,
            rows_4
        );
    }
}

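// Permutes the rows of the packed 4-bit GPTQ weight matrix according to
// q_perm so that rows belonging to the same quantization group become
// contiguous. The kernel works on uint64 words (two adjacent columns of
// packed data) and assembles each destination word by gathering one nibble
// pair from each of eight source rows.
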
__global__ void make_sequential_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const uint16_t* __restrict__ q_perm,
    const int w_height,
    const int w_width
)
{
    const uint64_t* w2 = (uint64_t*) w;
    uint64_t* w_new2 = (uint64_t*) w_new;
    int w2_stride = w_width >> 1;

    int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
    if (w2_column >= w2_stride) return;

    int w_new2_row = blockIdx.y;

    int q_perm_idx = w_new2_row << 3;

    uint64_t dst = 0;

    #pragma unroll
    for (int i = 0; i < 8; i++)
    {
        int source_row = q_perm[q_perm_idx++];

        int w2_row = source_row >> 3;
        int w2_subrow = source_row & 0x07;
        int w2_row_shift = w2_subrow << 2;
        int wnew2_row_shift = i << 2;

        uint64_t src = w2[w2_row * w2_stride + w2_column];
        src >>= w2_row_shift;
        src &= 0x0000000f0000000f;
        src <<= wnew2_row_shift;
        dst |= src;
    }

    w_new2[w_new2_row * w2_stride + w2_column] = dst;
}

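// Builds the permutation that makes the GPTQ act-order (g_idx) layout
// sequential: a histogram of g_idx followed by an exclusive prefix sum gives
// each group its starting row, every row is then assigned a stable target
// position within its group, and the resulting forward/inverse maps are
// stored as uint16 in cuda_q_perm / cuda_q_invperm before
// make_sequential_kernel rewrites the packed weights in that order. Returns
// false if the temporary buffer for the re-packed weights cannot be
// allocated.
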
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
{
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    uint32_t* cuda_new_qweight = NULL;
    cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
    if (err != cudaSuccess) {
        cudaError_t cuda_status = cudaGetLastError(); // Clear error
        return false;
    }

    uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
    uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
    uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));

    // Group histogram

    for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;

    // Group map

    for (int i = 0, acc = 0; i < groups; i++)
    {
        short tmp = cpu_g_idx_map[i];
        cpu_g_idx_map[i] = acc;
        acc += tmp;
    }

    // X map (inverse)

    for (int row = 0; row < height; row++)
    {
        uint32_t target_group = cpu_g_idx[row];
        uint32_t target_row = cpu_g_idx_map[target_group];
        cpu_g_idx_map[target_group]++;
        cpu_x_map_inv[row] = target_row;
    }

    // X map

    for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;

    // Reduce to uint16_t

    uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
    uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
    for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
    for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];

    // Move to CUDA

    cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
    cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);

    // Rearrange rows in w

    dim3 blockDim, gridDim;
    blockDim.x = THREADS_X;
    blockDim.y = 1;
    gridDim.x = DIVIDE(width, THREADS_X);
    gridDim.y = height / 8;

    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
    (
        cuda_q_weight,
        cuda_new_qweight,
        cuda_q_perm,
        height / 8,
        width
    );

    // Replace qweights

    cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);

    // Cleanup

    cudaDeviceSynchronize();

    cudaFree(cuda_new_qweight);
    free(cpu_g_idx_map);
    free(cpu_x_map);
    free(cpu_x_map_inv);

    return true;
}
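
// Illustrative usage sketch (not part of this file; the actual host code in
// the repository may differ, and it assumes QMatrix exposes its height/width
// members as the constructor's initializer list suggests): given a fully
// initialized QMatrix* qm, the full fp16 weight matrix can be materialized
// with reconstruct():
//
//     half* dq;
//     cudaMalloc(&dq, (size_t)qm->height * (size_t)qm->width * sizeof(half));
//     qm->reconstruct(dq);                                  // launches on the current stream
//     cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());
//     // ... use dq ...
//     cudaFree(dq);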