mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix exllama overflows
This commit is contained in:
parent
94d243b3d7
commit
6cb6020e4d
@ -85,7 +85,7 @@ __global__ void q4_matmul_kernel
|
|||||||
if constexpr (use_half2)
|
if constexpr (use_half2)
|
||||||
{
|
{
|
||||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||||
|
|
||||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||||
@ -93,7 +93,7 @@ __global__ void q4_matmul_kernel
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
half w_scale = w_scales_.item(group, w_column);
|
half w_scale = w_scales_.item(group, w_column);
|
||||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||||
|
|
||||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||||
@ -110,7 +110,7 @@ __global__ void q4_matmul_kernel
|
|||||||
{
|
{
|
||||||
int group = k / groupsize;
|
int group = k / groupsize;
|
||||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||||
|
|
||||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||||
@ -119,7 +119,7 @@ __global__ void q4_matmul_kernel
|
|||||||
{
|
{
|
||||||
int group = k / groupsize;
|
int group = k / groupsize;
|
||||||
half w_scale = w_scales_.item(group, w_column);
|
half w_scale = w_scales_.item(group, w_column);
|
||||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||||
|
|
||||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||||
|
@ -189,7 +189,7 @@ __global__ void reconstruct_kernel
|
|||||||
int group = row / groupsize;
|
int group = row / groupsize;
|
||||||
|
|
||||||
half w_scale = w_scales_.item(group, column);
|
half w_scale = w_scales_.item(group, column);
|
||||||
uint32_t w_zero = w_zeros_.item(group, column) + 1;
|
uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F;
|
||||||
|
|
||||||
uint32_t w_read = w_.item_uint32_t(row, column);
|
uint32_t w_read = w_.item_uint32_t(row, column);
|
||||||
half* out_ptr = out_.item_ptr(row, column);
|
half* out_ptr = out_.item_ptr(row, column);
|
||||||
|
@ -152,10 +152,10 @@ __global__ void gemm_half_q_half_gptq_kernel
|
|||||||
half2 y1y16[4][2];
|
half2 y1y16[4][2];
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_h2(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||||
|
|
||||||
// __syncthreads();
|
// __syncthreads();
|
||||||
|
|
||||||
@ -174,10 +174,10 @@ __global__ void gemm_half_q_half_gptq_kernel
|
|||||||
nextgroup += groupsize;
|
nextgroup += groupsize;
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_h2(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -237,10 +237,10 @@ __global__ void reconstruct_gptq_kernel
|
|||||||
half2 y1y16[4][2];
|
half2 y1y16[4][2];
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_h2(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
@ -255,10 +255,10 @@ __global__ void reconstruct_gptq_kernel
|
|||||||
nextgroup += groupsize;
|
nextgroup += groupsize;
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_h2(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int p = 0; p < 4; p++)
|
for (int p = 0; p < 4; p++)
|
||||||
|
Loading…
Reference in New Issue
Block a user