Merge branch 'huggingface:main' into main

This commit is contained in:
Lukasz Olszewski 2023-12-22 14:38:26 +01:00 committed by GitHub
commit b223ac70b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 535 additions and 256 deletions

View File

@ -39,7 +39,7 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
## Supported Hardware
TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
TGI also supports ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and support may be extended in the future:
* Quantization (GPTQ, AWQ, etc.)
@ -47,5 +47,5 @@ TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with pag
* Kernel for sliding window attention (Mistral)
TGI is also supported on the following AI hardware accelerators:
- *Habana first-gen Gaudi and Gaudi2:* check out this [example](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
- *Habana first-gen Gaudi and Gaudi2:* check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
* *AWS Inferentia2:* check out this [guide](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) on how to serve models with TGI on Inferentia2.

View File

@ -2,6 +2,7 @@
#define _config_h
#define MAX_Q_GEMM_ROWS 50
#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS
#define QMODE_2BIT 1
#define QMODE_3BIT 1
@ -10,4 +11,5 @@
#define QMODE_6BIT 0
#define QMODE_8BIT 0
#endif

View File

@ -10,16 +10,19 @@
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define GPTQ_BLOCK_KN_SIZE 128
#define GPTQ_BLOCK_M_SIZE_MAX 8
#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
#define EXL2_BLOCK_KN_SIZE 64
#define EXL2_BLOCK_M_SIZE_MAX 8
#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
#define CLEAR_N_SIZE 256
#include "q_gemm_kernel.cuh"
#include "q_gemm_kernel_gptq.cuh"
#include "compat_gemm.cuh"
void gemm_half_q_half_cuda_part
(
const half* a,
@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part
int size_n,
int size_k,
int m_count,
bool clear
bool clear,
const half* r_weights,
int r_weights_stride,
bool mul_r_weights
)
{
if (!b->is_gptq)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.x = EXL2_BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
kernel<<<gridDim, blockDim>>>
(
@ -55,7 +61,7 @@ void gemm_half_q_half_cuda_part
size_n,
size_k,
b->groups,
b->groupsize,
b->cuda_q_group_map,
b->cuda_q_perm,
b->rows_8,
b->rows_6,
@ -63,24 +69,27 @@ void gemm_half_q_half_cuda_part
b->rows_4,
b->rows_3,
b->rows_2,
clear
clear,
r_weights,
r_weights_stride
);
}
else
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.x = GPTQ_BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);
// DBGX((uint64_t) b->cuda_q_perm);
// DBGI(b->rows_4);
// DBGI(b->height);
// DBGX((uint64_t) r_weights);
// if (r_weights)
// print_global_mem(r_weights, 1, 1, 1);
// DBGI(r_weights_stride);
kernel<<<gridDim, blockDim>>>
(
@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part
size_n,
size_k,
b->groups,
b->groupsize,
b->gptq_groupsize,
b->cuda_q_perm,
b->rows_4,
clear
clear,
r_weights,
r_weights_stride
);
}
}
@ -112,13 +123,14 @@ void gemm_half_q_half_cuda
int size_k,
bool clear,
half* temp_dq,
bool force_cuda
bool force_cuda,
const half* r_weights,
const int r_weights_stride,
bool mul_r_weights
)
{
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
{
//printf("cublas\n");
// Reconstruct FP16 matrix, then cuBLAS
if (!temp_dq) temp_dq = b->temp_dq;
@ -139,12 +151,12 @@ void gemm_half_q_half_cuda
//const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f;
//cublasSgemmEx(cublas_handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// size_n, size_m, size_k,
// &alpha, temp_dq, CUDA_R_16F, size_n,
// a, CUDA_R_16F, size_k,
// &beta, c, CUDA_R_16F, size_n);
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// size_n, size_m, size_k,
// &alpha, temp_dq, CUDA_R_16F, size_n,
// a, CUDA_R_16F, size_k,
// &beta, c, CUDA_R_16F, size_n);
//const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f;
@ -158,24 +170,21 @@ void gemm_half_q_half_cuda
}
else
{
//printf("cuda\n");
// Quantized matmul
//if (clear) clear_tensor_cuda(c, size_m, size_n);
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
int max_chunks = size_m / block_m_size_max;
int last_chunk = max_chunks * block_m_size_max;
int last_chunk_size = size_m - last_chunk;
if (max_chunks)
{
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
}
if (last_chunk_size)
{
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
}
}
}
@ -201,11 +210,10 @@ void clear_tensor_cuda
int size_n
)
{
return;
dim3 blockDim, gridDim;
blockDim.x = CLEAR_N_SIZE;
blockDim.y = 1;
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
gridDim.y = size_m;
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
// dim3 blockDim, gridDim;
// blockDim.x = CLEAR_N_SIZE;
// blockDim.y = 1;
// gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
// gridDim.y = size_m;
// clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
}

View File

@ -20,7 +20,10 @@ void gemm_half_q_half_cuda
int size_k,
bool clear = false,
half* reconstruct = NULL,
bool force_cuda = false
bool force_cuda = false,
const half* r_weights = NULL,
const int r_weights_stride = 0,
bool mul_r_weights = false
);
void clear_tensor_cuda

View File

@ -1,8 +1,5 @@
#include "compat.cuh"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
@ -60,6 +57,47 @@ __forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, c
return fma(result_f, qs_f, g_result);
}
// Multiplies the 8 weights held in dq (four half2 pairs) with the 8 consecutive
// half values starting at a_ptr, scales the sum by qs_h and adds g_result.
//
//   dq        four half2 values; low/high lanes are consumed in that order
//   a_ptr     pointer to 8 half activations (read-only here)
//   g_result  running result that the scaled dot product is added to
//   qs_h      scale applied to the dot product before the final add
//
// The sum is kept in FP32 to avoid potential overflow, since the unscaled
// weights are in the range -128..127. It is rounded to half only after
// scaling, and the final add with g_result is done in half precision.
__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
{
    float acc = 0.0f;

    #pragma unroll
    for (int pair = 0; pair < 4; pair++)
    {
        const half2 w = dq[pair];
        const float a_lo = __half2float(a_ptr[2 * pair]);
        const float a_hi = __half2float(a_ptr[2 * pair + 1]);

        // Low lane first, then high lane: same accumulation order per pair.
        acc = fma(__low2float(w), a_lo, acc);
        acc = fma(__high2float(w), a_hi, acc);
    }

    acc *= __half2float(qs_h);
    return __hadd(__float2half_rn(acc), g_result);
}
// Half-precision dot product of the 16 weights in dq (eight half2 pairs) with
// 16 half activations starting at a_ptr. Returns sum * qs_h + g_result, with
// the final step performed as a single half fused multiply-add.
//
// All accumulation is done in half2; the two lanes are added together only
// after the loop.
// NOTE(review): a_ptr is reinterpreted as half2*, so it presumably has to be
// suitably aligned for half2 loads -- confirm at the call sites.
__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
{
    const half2* a_pairs = (const half2*)a_ptr;
    half2 acc = {};

    #pragma unroll
    for (int pair = 0; pair < 8; pair++)
    {
        acc = __hfma2(dq[pair], a_pairs[pair], acc);
    }

    const half lane_sum = __hadd(__low2half(acc), __high2half(acc));
    return __hfma(lane_sum, qs_h, g_result);
}
// Half-precision dot product of the 32 weights in dq (sixteen half2 pairs)
// with 32 half activations starting at a_ptr. Returns sum * qs_h + g_result,
// with the final step performed as a single half fused multiply-add.
//
// All accumulation is done in half2; the two lanes are added together only
// after the loop.
// NOTE(review): a_ptr is reinterpreted as half2*, so it presumably has to be
// suitably aligned for half2 loads -- confirm at the call sites.
__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
{
    const half2* a_pairs = (const half2*)a_ptr;
    half2 acc = {};

    #pragma unroll
    for (int pair = 0; pair < 16; pair++)
    {
        acc = __hfma2(dq[pair], a_pairs[pair], acc);
    }

    const half lane_sum = __hadd(__low2half(acc), __high2half(acc));
    return __hfma(lane_sum, qs_h, g_result);
}
typedef void (*fp_gemm_half_q_half_kernel)
@ -73,7 +111,7 @@ typedef void (*fp_gemm_half_q_half_kernel)
const int,
const int,
const int,
const int,
const uint16_t*,
const uint16_t*,
const int,
const int,
@ -81,10 +119,12 @@ typedef void (*fp_gemm_half_q_half_kernel)
const int,
const int,
const int,
const bool
const bool,
const half*,
const int
);
template <bool first_block, int m_count>
template <int m_count, bool use_r_weights, bool mul_r_weights>
__global__ void gemm_half_q_half_kernel
(
const half* __restrict__ a,
@ -96,7 +136,7 @@ __global__ void gemm_half_q_half_kernel
const int size_n,
const int size_k,
const int groups,
const int groupsize,
const uint16_t* __restrict__ b_q_group_map,
const uint16_t* __restrict__ b_q_perm,
const int rows_8,
const int rows_6,
@ -104,7 +144,9 @@ __global__ void gemm_half_q_half_kernel
const int rows_4,
const int rows_3,
const int rows_2,
const bool clear
const bool clear,
const half* r_weights,
const int r_weights_stride
)
{
MatrixView_half a_(a, size_m, size_k);
@ -115,18 +157,34 @@ __global__ void gemm_half_q_half_kernel
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Read weights
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
if constexpr (use_r_weights)
{
uint16_t any_w = 0;
const half* w_ptr = r_weights;
for (int m = 0; m < m_count; ++m)
{
weights[m].as_half = *w_ptr;
w_ptr += r_weights_stride;
any_w |= weights[m].as_uint16;
}
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
}
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
__shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
@ -135,6 +193,7 @@ __global__ void gemm_half_q_half_kernel
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0 = a_ptr[b_q_perm[offset_k + t]];
// half a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
@ -153,14 +212,19 @@ __global__ void gemm_half_q_half_kernel
// Find initial group
int group = offset_k / groupsize;
//int group = offset_k / groupsize;
int group = b_q_group_map[offset_k * 2];
// if (offset_m == 0 && t == 0)
// DBGI2(offset_k, group);
// Preload scales
float scales[MAX_GROUPS_IN_BLOCK][4];
half scales[EXL2_MAX_GROUPS_IN_BLOCK][4];
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
for (int g = 0; g < groups_in_block; g++)
//int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
int temp_k = offset_k;
for (int g = 0; temp_k < end_k; g++)
{
int qscales[4];
b_q_scale_.item4(qscales, group + g, n);
@ -168,11 +232,12 @@ __global__ void gemm_half_q_half_kernel
qscales[1]++;
qscales[2]++;
qscales[3]++;
float maxscale = __half2float(b_q_scale_max[group + g]);
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
half maxscale = b_q_scale_max[group + g];
scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale);
scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale);
scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale);
scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale);
temp_k += b_q_group_map[temp_k * 2 + 1];
}
// a, b offset
@ -193,20 +258,20 @@ __global__ void gemm_half_q_half_kernel
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
int a_stride = EXL2_BLOCK_KN_SIZE;
// Initial group
int scales_idx = 0;
float qs_f0 = scales[scales_idx][0];
float qs_f1 = scales[scales_idx][1];
float qs_f2 = scales[scales_idx][2];
float qs_f3 = scales[scales_idx][3];
int nextgroup = offset_k + groupsize;
half qs_h0 = scales[scales_idx][0];
half qs_h1 = scales[scales_idx][1];
half qs_h2 = scales[scales_idx][2];
half qs_h3 = scales[scales_idx][3];
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
// Column result
float block_c[m_count][4] = {};
half block_c[m_count][4] = {};
// Dequantize groups
@ -218,11 +283,11 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
@ -240,10 +305,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 8;
}
@ -256,11 +322,11 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
@ -279,10 +345,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 16;
}
@ -295,11 +362,11 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
@ -320,10 +387,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 32;
}
@ -337,11 +405,11 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
@ -358,10 +426,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 8;
}
@ -374,11 +443,11 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
@ -397,10 +466,11 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 32;
}
@ -413,15 +483,15 @@ __global__ void gemm_half_q_half_kernel
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
qs_h0 = scales[scales_idx][0];
qs_h1 = scales[scales_idx][1];
qs_h2 = scales[scales_idx][2];
qs_h3 = scales[scales_idx][3];
nextgroup += b_q_group_map[k * 2 + 1];
}
#pragma unroll
for (int j = 0; j < 2; j++)
for (int j = 0; j < 1; j++)
{
int4 load_int4[1];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
@ -434,15 +504,16 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
}
a_ptr += 16;
}
k += 32;
k += 16;
}
// Accumulate column sums in c
@ -450,38 +521,60 @@ __global__ void gemm_half_q_half_kernel
for (int m = 0; m < m_count; m++)
{
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
if constexpr (mul_r_weights)
{
half2 w_mul2 = __half2half2(weights[m].as_half);
result01 = __hmul2(result01, w_mul2);
result23 = __hmul2(result23, w_mul2);
}
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
// *out = result01;
// *(out + 1) = result23;
}
}
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
// Maps a runtime m_count to the matching gemm_half_q_half_kernel instantiation
// for a fixed (use_r_weights, mul_r_weights) pair of compile-time flags.
//
// Only row counts up to EXL2_BLOCK_M_SIZE_MAX are compiled in (each case is
// guarded by the preprocessor); any other m_count yields NULL.
template <bool use_r_weights, bool mul_r_weights>
struct map_m_count_exl2 {
    static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
    {
        switch (m_count)
        {
            #if EXL2_BLOCK_M_SIZE_MAX >= 1
            case 1: return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>;
            #endif
            #if EXL2_BLOCK_M_SIZE_MAX >= 2
            case 2: return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>;
            #endif
            #if EXL2_BLOCK_M_SIZE_MAX >= 3
            case 3: return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>;
            #endif
            #if EXL2_BLOCK_M_SIZE_MAX >= 4
            case 4: return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>;
            #endif
            #if EXL2_BLOCK_M_SIZE_MAX >= 5
            case 5: return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>;
            #endif
            #if EXL2_BLOCK_M_SIZE_MAX >= 6
            case 6: return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>;
            #endif
            #if EXL2_BLOCK_M_SIZE_MAX >= 7
            case 7: return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>;
            #endif
            #if EXL2_BLOCK_M_SIZE_MAX >= 8
            case 8: return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>;
            #endif
            // Unsupported (or compiled-out) row count.
            default: return NULL;
        }
    }
};
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
#endif
if (!r_weights && !mul_r_weights) return map_m_count_exl2<false, false>::pick_gemm_half_q_half_kernel(m_count);
if (!r_weights && mul_r_weights) return map_m_count_exl2<false, true>::pick_gemm_half_q_half_kernel(m_count);
if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count);
if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count);
return NULL;
}

View File

@ -18,6 +18,15 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
return __half2float(__low2half(result)) + __half2float(__high2half(result));
}
// Accumulates dq[i] * a[i] over four half2 pairs (8 weights against the 8 half
// activations starting at a_ptr) and returns the half2 accumulator as-is: the
// low and high lanes are NOT added together here, that is left to the caller.
// NOTE(review): a_ptr is reinterpreted as half2*, so it presumably has to be
// suitably aligned for half2 loads -- confirm at the call sites.
__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr)
{
    const half2* a_pairs = (const half2*)a_ptr;
    half2 acc = {};

    #pragma unroll
    for (int pair = 0; pair < 4; pair++)
    {
        acc = __hfma2(dq[pair], a_pairs[pair], acc);
    }

    return acc;
}
typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
const half*,
@ -32,10 +41,12 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
const int,
const uint16_t*,
const int,
const bool
const bool,
const half*,
const int
);
template <bool first_block, int m_count>
template <int m_count, bool use_r_weights, bool mul_r_weights>
__global__ void gemm_half_q_half_gptq_kernel
(
const half* __restrict__ a,
@ -50,7 +61,9 @@ __global__ void gemm_half_q_half_gptq_kernel
const int groupsize,
const uint16_t* __restrict__ b_q_perm,
const int rows_4,
const bool clear
const bool clear,
const half* r_weights,
const int r_weights_stride
)
{
MatrixView_half a_(a, size_m, size_k);
@ -62,19 +75,35 @@ __global__ void gemm_half_q_half_gptq_kernel
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Read weights
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
if constexpr (use_r_weights)
{
uint16_t any_w = 0;
const half* w_ptr = r_weights;
for (int m = 0; m < m_count; ++m)
{
weights[m].as_half = *w_ptr;
w_ptr += r_weights_stride;
any_w |= weights[m].as_uint16;
}
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
}
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
__shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
@ -113,16 +142,16 @@ __global__ void gemm_half_q_half_gptq_kernel
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
int a_stride = GPTQ_BLOCK_KN_SIZE;
// Initial group
int zeros[4];
float scales[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@ -132,7 +161,7 @@ __global__ void gemm_half_q_half_gptq_kernel
// Column result
float block_c[m_count][4] = {};
half2 block_c[m_count][4] = {};
// Dequantize and multiply
@ -144,7 +173,7 @@ __global__ void gemm_half_q_half_gptq_kernel
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@ -166,10 +195,11 @@ __global__ void gemm_half_q_half_gptq_kernel
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
}
b_ptr += size_n;
@ -182,38 +212,62 @@ __global__ void gemm_half_q_half_gptq_kernel
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0]));
half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1]));
half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2]));
half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3]));
half2 result01 = __halves2half2(result0, result1);
half2 result23 = __halves2half2(result2, result3);
if constexpr (mul_r_weights)
{
half2 w_mul2 = __half2half2(weights[m].as_half);
result01 = __hmul2(result01, w_mul2);
result23 = __hmul2(result23, w_mul2);
}
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
// Maps a runtime m_count to the matching gemm_half_q_half_gptq_kernel
// instantiation for a fixed (use_r_weights, mul_r_weights) pair of
// compile-time flags.
//
// Only row counts up to GPTQ_BLOCK_M_SIZE_MAX are compiled in (each case is
// guarded by the preprocessor); any other m_count yields NULL.
template <bool use_r_weights, bool mul_r_weights>
struct map_m_count_gptq {
    static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count)
    {
        switch (m_count)
        {
            #if GPTQ_BLOCK_M_SIZE_MAX >= 1
            case 1: return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>;
            #endif
            #if GPTQ_BLOCK_M_SIZE_MAX >= 2
            case 2: return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>;
            #endif
            #if GPTQ_BLOCK_M_SIZE_MAX >= 3
            case 3: return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>;
            #endif
            #if GPTQ_BLOCK_M_SIZE_MAX >= 4
            case 4: return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>;
            #endif
            #if GPTQ_BLOCK_M_SIZE_MAX >= 5
            case 5: return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>;
            #endif
            #if GPTQ_BLOCK_M_SIZE_MAX >= 6
            case 6: return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>;
            #endif
            #if GPTQ_BLOCK_M_SIZE_MAX >= 7
            case 7: return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>;
            #endif
            #if GPTQ_BLOCK_M_SIZE_MAX >= 8
            case 8: return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>;
            #endif
            // Unsupported (or compiled-out) row count.
            default: return NULL;
        }
    }
};
// Selects the GPTQ GEMM kernel for a row-block of `m_count` rows.
//
// The two runtime flags are converted into compile-time template arguments by
// choosing one of the four map_m_count_gptq<r_weights, mul_r_weights>
// specialisations, which then performs the m_count lookup.
//
// Returns whatever the chosen map_m_count_gptq lookup returns, i.e. NULL when
// `m_count` has no compiled-in instantiation.
//
// Note: an earlier dispatch chain keyed only on m_count (guarded by
// BLOCK_M_SIZE_MAX and using a two-argument template list) used to sit in
// front of this logic. It returned before the flags were ever examined, so it
// has been removed; the flag-aware dispatch below is the only path.
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights)
{
    if (r_weights)
    {
        if (mul_r_weights) return map_m_count_gptq< true,  true>::pick_gemm_half_q_half_gptq_kernel(m_count);
        return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
    }

    if (mul_r_weights) return map_m_count_gptq<false,  true>::pick_gemm_half_q_half_gptq_kernel(m_count);
    return map_m_count_gptq<false, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
}

View File

@ -57,6 +57,7 @@ QMatrix::QMatrix
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint16_t* _q_group_map,
uint32_t* _gptq_qzeros,
half* _gptq_scales,
@ -80,13 +81,17 @@ QMatrix::QMatrix
cuda_q_scale = _q_scale;
cuda_q_scale_max = _q_scale_max;
cuda_q_groups = _q_groups;
cuda_q_group_map = _q_group_map;
cuda_gptq_qzeros = _gptq_qzeros;
cuda_gptq_scales = _gptq_scales;
is_gptq = (_gptq_qzeros != NULL);
groupsize = 1;
while (groupsize * groups < height) groupsize *= 2;
if (is_gptq)
{
gptq_groupsize = 1;
while (gptq_groupsize * groups < height) gptq_groupsize *= 2;
}
// Create group map
@ -102,15 +107,26 @@ QMatrix::QMatrix
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
int row = 0;
for (int i = 0; i < groups; i++)
{
int bits = cpu_q_groups[i * 2];
if (bits == 8) rows_8 += groupsize;
if (bits == 6) rows_6 += groupsize;
if (bits == 5) rows_5 += groupsize;
if (bits == 4) rows_4 += groupsize;
if (bits == 3) rows_3 += groupsize;
if (bits == 2) rows_2 += groupsize;
int rows;
if (i < groups - 1)
{
int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
rows = qrows * 32 / bits;
}
else rows = height - row;
if (bits == 8) rows_8 += rows;
if (bits == 6) rows_6 += rows;
if (bits == 5) rows_5 += rows;
if (bits == 4) rows_4 += rows;
if (bits == 3) rows_3 += rows;
if (bits == 2) rows_2 += rows;
row += rows;
}
free(cpu_q_groups);
@ -138,6 +154,13 @@ QMatrix::QMatrix
}
}
// DBGI(rows_8);
// DBGI(rows_6);
// DBGI(rows_5);
// DBGI(rows_4);
// DBGI(rows_3);
// DBGI(rows_2);
// Shuffle quantized data
dim3 blockDim, gridDim;
@ -283,10 +306,10 @@ __global__ void reconstruct_kernel
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max,
//const uint16_t* __restrict__ b_q_groups,
const uint16_t* __restrict__ b_q_group_map,
const int size_k,
const int size_n,
const int groupsize,
//const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_8,
@ -317,7 +340,8 @@ __global__ void reconstruct_kernel
// Find initial group
int group = offset_k / groupsize;
// int group = offset_k / groupsize;
int group = b_q_group_map[offset_k * 2];
int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
@ -337,7 +361,7 @@ __global__ void reconstruct_kernel
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
half2 qs_h2 = __halves2half2(qs_h, qs_h);
int nextgroup = offset_k + groupsize;
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int k = offset_k;
@ -347,7 +371,7 @@ __global__ void reconstruct_kernel
while (k < rows_8 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
@ -363,7 +387,7 @@ __global__ void reconstruct_kernel
while (k < rows_6 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
@ -380,7 +404,7 @@ __global__ void reconstruct_kernel
while (k < rows_5 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
@ -399,7 +423,7 @@ __global__ void reconstruct_kernel
while (k < rows_4 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
@ -414,7 +438,7 @@ __global__ void reconstruct_kernel
while (k < rows_3 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
@ -431,8 +455,8 @@ __global__ void reconstruct_kernel
while (k < rows_2 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
@ -441,7 +465,7 @@ __global__ void reconstruct_kernel
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
k += 16;
}
}
@ -461,10 +485,10 @@ void QMatrix::reconstruct(half* out)
cuda_q_perm,
cuda_q_scale,
cuda_q_scale_max,
//cuda_q_groups,
cuda_q_group_map,
height,
width,
groupsize,
//groupsize,
groups,
out,
rows_8,
@ -487,7 +511,7 @@ void QMatrix::reconstruct(half* out)
//const uint16_t* __restrict__ b_q_groups,
height,
width,
groupsize,
gptq_groupsize,
groups,
out,
rows_4

View File

@ -18,7 +18,7 @@ public:
int height;
int width;
int groups;
int groupsize;
int gptq_groupsize;
int rows_8;
int rows_6;
@ -33,6 +33,7 @@ public:
uint32_t* cuda_q_scale = NULL;
half* cuda_q_scale_max = NULL;
uint16_t* cuda_q_groups = NULL;
uint16_t* cuda_q_group_map = NULL;
uint32_t* cuda_gptq_qzeros = NULL;
half* cuda_gptq_scales = NULL;
@ -53,6 +54,7 @@ public:
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint16_t* _q_group_map,
uint32_t* _gptq_qzeros,
half* _gptq_scales,

View File

@ -7,6 +7,7 @@ union half2_uint32
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
__device__ half2_uint32() : as_uint32(0) {}
};
union half_uint16
@ -15,6 +16,7 @@ union half_uint16
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
__device__ half_uint16() : as_uint16(0) {}
};
// Max_scale premultiplied by 1/256

View File

@ -1,3 +1,11 @@
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
@ -40,3 +48,7 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=
if (abort) exit(code);
}
}
void print_global_mem(const half* ptr, int rows, int columns, int stride);
#endif

View File

@ -31,6 +31,7 @@ uintptr_t make_q_matrix
torch::Tensor q_scale,
torch::Tensor q_scale_max,
torch::Tensor q_groups,
torch::Tensor q_group_map,
torch::Tensor gptq_qzeros,
torch::Tensor gptq_scales,
torch::Tensor gptq_g_idx,
@ -43,6 +44,7 @@ uintptr_t make_q_matrix
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
TORCH_CHECK_DTYPE_OPT(q_group_map, kShort);
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
@ -83,12 +85,15 @@ uintptr_t make_q_matrix
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(),
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
(half*) temp_dq.data_ptr()
);
if (m->failed) throw std::runtime_error("CUDA out of memory");
return reinterpret_cast<uintptr_t> (m);
}

View File

@ -32,10 +32,10 @@ def fresh_cache():
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
os.environ['HUGGINGFACE_HUB_CACHE'] = d
os.environ["HUGGINGFACE_HUB_CACHE"] = d
yield
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
os.environ['HUGGINGFACE_HUB_CACHE'] = current_value
os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
@ -47,7 +47,7 @@ def prefetched():
revision="main",
local_files_only=False,
repo_type="model",
allow_patterns=["*.safetensors"]
allow_patterns=["*.safetensors"],
)
yield model_id
@ -61,7 +61,15 @@ def test_weight_hub_files_offline_error(offline, fresh_cache):
def test_weight_hub_files_offline_ok(prefetched, offline):
    """Offline lookup of a prefetched model must find its single weight file.

    The returned entries are split with ``os.path.split`` and only the
    basename is compared, so the check holds whether or not an entry carries a
    directory prefix. All entries must share the same directory part.
    """
    # If the model is prefetched then we should be able to get the weight files from local cache
    filenames = weight_hub_files(prefetched)
    assert len(filenames) == 1

    root = None
    for f in filenames:
        curroot, filename = os.path.split(f)
        if root is None:
            # First entry fixes the directory every other entry must match.
            root = curroot
        else:
            assert root == curroot
        assert filename == "model.safetensors"
def test_weight_hub_files():

View File

@ -71,7 +71,7 @@ def _load_multi_mqa_gptq(
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
bits, groupsize = weights._get_gptq_params()
bits, groupsize, _ = weights._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA

View File

@ -27,6 +27,32 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
return output.view(output_shape)
# Group map needed for irregular group sizes
def make_group_map(q_groups, num_qrows):
    """Build a per-row lookup table from the flat ``q_groups`` descriptor.

    ``q_groups`` is read as consecutive pairs ``(bits, first_qrow)``, one pair
    per group. A group's packed-row count is the distance to the next group's
    ``first_qrow`` (or to ``num_qrows`` for the last group), and its unpacked
    row count is ``qrows * 32 // bits``.

    Returns a 1-D ``torch.short`` tensor on ``q_groups.device`` holding two
    values per unpacked row: the index of the group the row belongs to, then
    the number of rows left in that group counting the row itself.
    """
    flat = q_groups.tolist()
    group_count = len(flat) // 2

    entries = []
    for idx in range(group_count):
        bit_width = flat[2 * idx]
        first_qrow = flat[2 * idx + 1]
        # The last group has no successor, so it runs to the end of the matrix.
        end_qrow = flat[2 * idx + 3] if idx < group_count - 1 else num_qrows
        row_count = (end_qrow - first_qrow) * 32 // bit_width

        # Count down: row_count for the first row of the group, 1 for the last.
        for remaining in range(row_count, 0, -1):
            entries.append(idx)
            entries.append(remaining)

    return torch.tensor(entries, dtype=torch.short, device=q_groups.device)
# Create Q matrix
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
"""
Create Q matrix
@ -37,6 +63,10 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["q_scale_max"] /= 256
w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short()
if "q_group_map" not in w:
w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
return make_q_matrix(
w["q_weight"],
w["q_perm"],
@ -44,6 +74,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["q_scale"],
w["q_scale_max"],
w["q_groups"],
w["q_group_map"],
none_tensor,
none_tensor,
none_tensor,
@ -70,6 +101,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
w["g_idx"].cpu(),
@ -84,6 +116,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
none_tensor,

View File

@ -18,7 +18,9 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]:
def _cached_weight_files(
model_id: str, revision: Optional[str], extension: str
) -> List[str]:
"""Guess weight files from the cached revision snapshot directory"""
d = _get_cached_revision_directory(model_id, revision)
if not d:
@ -27,7 +29,9 @@ def _cached_weight_files(model_id: str, revision: Optional[str], extension: str)
return filenames
def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]:
def _weight_hub_files_from_model_info(
info: hf_api.ModelInfo, extension: str
) -> List[str]:
return [
s.rfilename
for s in info.siblings
@ -44,20 +48,27 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
# see _weight_hub_files_from_model_info, that's also what is
# done there with the len(s.rfilename.split("/")) == 1 condition
root, _, files = next(os.walk(str(d)))
filenames = [f for f in files
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "training" not in f]
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "adapter" not in f
and "training" not in f
]
return filenames
def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]:
def _get_cached_revision_directory(
model_id: str, revision: Optional[str]
) -> Optional[Path]:
if revision is None:
revision = "main"
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
file_download.repo_folder_name(repo_id=model_id, repo_type="model")
)
if not repo_cache.is_dir():
# No cache for this model
@ -85,7 +96,7 @@ def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Op
def weight_hub_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
"""Get the weights filenames on the hub"""
api = HfApi()

View File

@ -19,6 +19,7 @@ from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.log import log_once
HAS_AWQ = True
try:
@ -35,10 +36,11 @@ HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
logger.warning(
V2 = False
log_once(
logger.warning,
"Disabling exllama v2 and using v1 instead because there are issues when sharding"
)
V2 = False
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False

View File

@ -0,0 +1,6 @@
from functools import lru_cache
@lru_cache(maxsize=10)
def log_once(log, msg: str):
    """Emit ``msg`` through ``log`` unless the same call is still cached.

    De-duplication relies on ``lru_cache``: a repeated ``(log, msg)`` pair is
    answered from the cache and ``log`` is not called again. The cache keeps
    only the 10 most recently used pairs, so an evicted message can be emitted
    again later. Both arguments must be hashable.
    """
    log(msg)

View File

@ -6,6 +6,7 @@ import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.utils.log import log_once
class Weights:
@ -161,7 +162,7 @@ class Weights:
else:
g_idx = None
bits, groupsize = self._get_gptq_params()
bits, groupsize, _ = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
slice_ = self._get_slice(f"{prefix}.weight")
@ -211,10 +212,10 @@ class Weights:
else:
g_idx = None
bits, groupsize = self._get_gptq_params()
bits, groupsize, desc_act = self._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -240,11 +241,15 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
use_exllama = True
bits, groupsize = self._get_gptq_params()
bits, groupsize, desc_act = self._get_gptq_params()
if bits != 4:
use_exllama = False
if desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
if g_idx is not None:
@ -274,12 +279,18 @@ class Weights:
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
logger.warning(
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
)
use_exllama = False
else:
logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
log_once(
logger.info,
f"Using exllama kernels v{HAS_EXLLAMA}"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if use_exllama and groupsize != -1:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
@ -288,14 +299,12 @@ class Weights:
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if use_exllama:
g_idx = g_idx - g_idx[0]
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
elif quantize == "awq":
bits, groupsize = self._get_gptq_params()
bits, groupsize, _ = self._get_gptq_params()
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@ -314,18 +323,20 @@ class Weights:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def _get_gptq_params(self) -> Tuple[int, int, bool]:
    """Return ``(bits, groupsize, desc_act)`` for the quantized checkpoint.

    The values are first read from the ``gptq_bits`` / ``gptq_groupsize``
    tensors; in that case ``desc_act`` is reported as ``False``. If reading
    the tensors raises ``SafetensorError`` or ``RuntimeError``, the attributes
    ``self.gptq_bits`` / ``self.gptq_groupsize`` are used instead, with
    ``self.gptq_desc_act`` defaulting to ``False`` when it is absent.

    Raises:
        SafetensorError or RuntimeError: the original tensor-lookup error,
            re-raised when the attribute fallback is unavailable as well.
    """
    try:
        bits = self.get_tensor("gptq_bits").item()
        groupsize = self.get_tensor("gptq_groupsize").item()
        desc_act = False
    except (SafetensorError, RuntimeError) as e:
        try:
            bits = self.gptq_bits
            groupsize = self.gptq_groupsize
            desc_act = getattr(self, "gptq_desc_act", False)
        except Exception:
            # Surface the tensor-lookup failure rather than the AttributeError
            # from the fallback, since it is the more informative of the two.
            raise e

    return bits, groupsize, desc_act
def _set_gptq_params(self, model_id, revision):
filename = "config.json"
@ -340,6 +351,7 @@ class Weights:
data = json.load(f)
self.gptq_bits = data["quantization_config"]["bits"]
self.gptq_groupsize = data["quantization_config"]["group_size"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
except Exception:
filename = "quantize_config.json"
try:
@ -353,6 +365,7 @@ class Weights:
data = json.load(f)
self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"]
self.gptq_desc_act = data["desc_act"]
except Exception:
filename = "quant_config.json"
try:
@ -366,5 +379,6 @@ class Weights:
data = json.load(f)
self.gptq_bits = data["w_bit"]
self.gptq_groupsize = data["q_group_size"]
self.gptq_desc_act = data["desc_act"]
except Exception:
pass