mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
fix exllama overflows
This commit is contained in:
parent
94d243b3d7
commit
6cb6020e4d
@ -85,7 +85,7 @@ __global__ void q4_matmul_kernel
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
@ -93,7 +93,7 @@ __global__ void q4_matmul_kernel
|
||||
else
|
||||
{
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
@ -110,7 +110,7 @@ __global__ void q4_matmul_kernel
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
@ -119,7 +119,7 @@ __global__ void q4_matmul_kernel
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
|
@ -189,7 +189,7 @@ __global__ void reconstruct_kernel
|
||||
int group = row / groupsize;
|
||||
|
||||
half w_scale = w_scales_.item(group, column);
|
||||
uint32_t w_zero = w_zeros_.item(group, column) + 1;
|
||||
uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F;
|
||||
|
||||
uint32_t w_read = w_.item_uint32_t(row, column);
|
||||
half* out_ptr = out_.item_ptr(row, column);
|
||||
|
@ -152,10 +152,10 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||
|
||||
// __syncthreads();
|
||||
|
||||
@ -174,10 +174,10 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
|
@ -237,10 +237,10 @@ __global__ void reconstruct_gptq_kernel
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@ -255,10 +255,10 @@ __global__ void reconstruct_gptq_kernel
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
for (int p = 0; p < 4; p++)
|
||||
|
Loading…
Reference in New Issue
Block a user