diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 34775139..0708c729 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -39,7 +39,7 @@ text-generation-launcher --model-id ## Supported Hardware -TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed. +TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed. TGI also has support for ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and support may be extended in the future: * Quantization (GPTQ, AWQ, etc.) @@ -47,5 +47,5 @@ TGI also has support for ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with pag * Kernel for sliding window attention (Mistral) TGI is also supported on the following AI hardware accelerators: -- *Habana first-gen Gaudi and Gaudi2:* check out this [example](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index) +- *Habana first-gen Gaudi and Gaudi2:* check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index) * *AWS Inferentia2:* check out this [guide](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) on how to serve models with TGI on Inferentia2.
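The hardware notes above translate directly into how a server is started. A minimal sketch, not part of this diff: the model id and image tag are placeholders, and it assumes the NVIDIA Container Toolkit linked above is installed so that `--gpus all` works:

    model=HuggingFaceH4/zephyr-7b-beta   # placeholder: any supported model id
    volume=$PWD/data                      # cache downloaded weights between runs
    docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
        ghcr.io/huggingface/text-generation-inference:latest \
        --model-id $model

The container's entry point is the same `text-generation-launcher --model-id ...` command referenced above, so the same flags can be passed to the launcher directly on a bare-metal install.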
diff --git a/server/exllamav2_kernels/exllamav2_kernels/config.h b/server/exllamav2_kernels/exllamav2_kernels/config.h index 86baaf41..32a1a37d 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/config.h +++ b/server/exllamav2_kernels/exllamav2_kernels/config.h @@ -2,6 +2,7 @@ #define _config_h #define MAX_Q_GEMM_ROWS 50 +#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS #define QMODE_2BIT 1 #define QMODE_3BIT 1 @@ -10,4 +11,5 @@ #define QMODE_6BIT 0 #define QMODE_8BIT 0 + #endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu index 351b9cd5..b4e4cf22 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu @@ -10,16 +10,19 @@ #include "quant/qdq_6.cuh" #include "quant/qdq_8.cuh" -#define BLOCK_KN_SIZE 128 -#define BLOCK_M_SIZE_MAX 8 -#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) +#define GPTQ_BLOCK_KN_SIZE 128 +#define GPTQ_BLOCK_M_SIZE_MAX 8 +#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32) + +#define EXL2_BLOCK_KN_SIZE 64 +#define EXL2_BLOCK_M_SIZE_MAX 8 +#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32) + #define CLEAR_N_SIZE 256 #include "q_gemm_kernel.cuh" #include "q_gemm_kernel_gptq.cuh" -#include "compat_gemm.cuh" - void gemm_half_q_half_cuda_part ( const half* a, @@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part int size_n, int size_k, int m_count, - bool clear + bool clear, + const half* r_weights, + int r_weights_stride, + bool mul_r_weights ) { if (!b->is_gptq) { dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; + blockDim.x = EXL2_BLOCK_KN_SIZE; blockDim.y = 1; blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4); gridDim.y = DIVIDE(size_m, m_count); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE); - fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count); + fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights); kernel<<>> ( @@ -55,7 +61,7 @@ void gemm_half_q_half_cuda_part size_n, size_k, b->groups, - b->groupsize, + b->cuda_q_group_map, b->cuda_q_perm, b->rows_8, b->rows_6, @@ -63,24 +69,27 @@ void gemm_half_q_half_cuda_part b->rows_4, b->rows_3, b->rows_2, - clear + clear, + r_weights, + r_weights_stride ); } else { dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; + blockDim.x = GPTQ_BLOCK_KN_SIZE; blockDim.y = 1; blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4); gridDim.y = DIVIDE(size_m, m_count); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE); - fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); + fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights); -// DBGX((uint64_t) b->cuda_q_perm); -// DBGI(b->rows_4); -// DBGI(b->height); +// DBGX((uint64_t) r_weights); +// if (r_weights) +// print_global_mem(r_weights, 1, 1, 1); +// DBGI(r_weights_stride); kernel<<>> ( @@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part size_n, size_k, b->groups, - b->groupsize, + b->gptq_groupsize, b->cuda_q_perm, b->rows_4, - clear + clear, + r_weights, + r_weights_stride ); } } @@ -112,13 +123,14 @@ void gemm_half_q_half_cuda int size_k, bool clear, half* temp_dq, - bool force_cuda + 
bool force_cuda, + const half* r_weights, + const int r_weights_stride, + bool mul_r_weights ) { if (size_m > MAX_Q_GEMM_ROWS && !force_cuda) { - //printf("cublas\n"); - // Reconstruct FP16 matrix, then cuBLAS if (!temp_dq) temp_dq = b->temp_dq; @@ -139,12 +151,12 @@ void gemm_half_q_half_cuda //const float alpha = 1.0f; //const float beta = clear ? 0.0f : 1.0f; //cublasSgemmEx(cublas_handle, - // CUBLAS_OP_N, - // CUBLAS_OP_N, - // size_n, size_m, size_k, - // &alpha, temp_dq, CUDA_R_16F, size_n, - // a, CUDA_R_16F, size_k, - // &beta, c, CUDA_R_16F, size_n); + // CUBLAS_OP_N, + // CUBLAS_OP_N, + // size_n, size_m, size_k, + // &alpha, temp_dq, CUDA_R_16F, size_n, + // a, CUDA_R_16F, size_k, + // &beta, c, CUDA_R_16F, size_n); //const float alpha = 1.0f; //const float beta = clear ? 0.0f : 1.0f; @@ -158,24 +170,21 @@ void gemm_half_q_half_cuda } else { - //printf("cuda\n"); - // Quantized matmul - //if (clear) clear_tensor_cuda(c, size_m, size_n); - - int max_chunks = size_m / BLOCK_M_SIZE_MAX; - int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX; + int max_chunks = size_m / block_m_size_max; + int last_chunk = max_chunks * block_m_size_max; int last_chunk_size = size_m - last_chunk; if (max_chunks) { - gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear); + gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights); } if (last_chunk_size) { - gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear); + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights); } } } @@ -201,11 +210,10 @@ void clear_tensor_cuda int size_n ) { - return; - dim3 blockDim, gridDim; - blockDim.x = CLEAR_N_SIZE; - blockDim.y = 1; - gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); - gridDim.y = size_m; - clear_kernel<<>>(c, size_m, size_n); +// dim3 blockDim, gridDim; +// blockDim.x = CLEAR_N_SIZE; +// blockDim.y = 1; +// gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); +// gridDim.y = size_m; +// clear_kernel<<>>(c, size_m, size_n); } diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh index c69f1a70..b643f915 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh @@ -20,7 +20,10 @@ void gemm_half_q_half_cuda int size_k, bool clear = false, half* reconstruct = NULL, - bool force_cuda = false + bool force_cuda = false, + const half* r_weights = NULL, + const int r_weights_stride = 0, + bool mul_r_weights = false ); void clear_tensor_cuda diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh index 0b899a84..9cd2ba01 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh @@ -1,8 +1,5 @@ #include "compat.cuh" -#include -#include - __forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) { half2 result = {}; @@ -60,6 +57,47 @@ __forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, c return fma(result_f, qs_f, 
g_result); } +__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) +{ + // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 + + float result = {}; + #pragma unroll + for (int i = 0; i < 4; i++) + { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} typedef void (*fp_gemm_half_q_half_kernel) @@ -73,7 +111,7 @@ typedef void (*fp_gemm_half_q_half_kernel) const int, const int, const int, - const int, + const uint16_t*, const uint16_t*, const int, const int, @@ -81,10 +119,12 @@ typedef void (*fp_gemm_half_q_half_kernel) const int, const int, const int, - const bool + const bool, + const half*, + const int ); -template +template __global__ void gemm_half_q_half_kernel ( const half* __restrict__ a, @@ -96,7 +136,7 @@ __global__ void gemm_half_q_half_kernel const int size_n, const int size_k, const int groups, - const int groupsize, + const uint16_t* __restrict__ b_q_group_map, const uint16_t* __restrict__ b_q_perm, const int rows_8, const int rows_6, @@ -104,7 +144,9 @@ __global__ void gemm_half_q_half_kernel const int rows_4, const int rows_3, const int rows_2, - const bool clear + const bool clear, + const half* r_weights, + const int r_weights_stride ) { MatrixView_half a_(a, size_m, size_k); @@ -115,18 +157,34 @@ __global__ void gemm_half_q_half_kernel // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4; int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; + int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE; - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n); int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k); int n = offset_n + t * 4; + // Read weights + + half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; + if constexpr (use_r_weights) + { + uint16_t any_w = 0; + const half* w_ptr = r_weights; + for (int m = 0; m < m_count; ++m) + { + weights[m].as_half = *w_ptr; + w_ptr += r_weights_stride; + any_w |= weights[m].as_uint16; + } + if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!) 
+ } + // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE]; if (offset_k + t < end_k) { @@ -135,6 +193,7 @@ __global__ void gemm_half_q_half_kernel const half* a_ptr = a_.item_ptr(offset_m + m, 0); half* block_a_ptr = block_a[m]; half a0 = a_ptr[b_q_perm[offset_k + t]]; +// half a0 = a_ptr[offset_k + t]; block_a_ptr[t] = a0; } } @@ -153,14 +212,19 @@ __global__ void gemm_half_q_half_kernel // Find initial group - int group = offset_k / groupsize; + //int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + +// if (offset_m == 0 && t == 0) +// DBGI2(offset_k, group); // Preload scales - float scales[MAX_GROUPS_IN_BLOCK][4]; + half scales[EXL2_MAX_GROUPS_IN_BLOCK][4]; - int groups_in_block = DIVIDE((end_k - offset_k), groupsize); - for (int g = 0; g < groups_in_block; g++) + //int groups_in_block = DIVIDE((end_k - offset_k), groupsize); + int temp_k = offset_k; + for (int g = 0; temp_k < end_k; g++) { int qscales[4]; b_q_scale_.item4(qscales, group + g, n); @@ -168,11 +232,12 @@ __global__ void gemm_half_q_half_kernel qscales[1]++; qscales[2]++; qscales[3]++; - float maxscale = __half2float(b_q_scale_max[group + g]); - scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale; - scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale; - scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale; - scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale; + half maxscale = b_q_scale_max[group + g]; + scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale); + scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale); + scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale); + scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale); + temp_k += b_q_group_map[temp_k * 2 + 1]; } // a, b offset @@ -193,20 +258,20 @@ __global__ void gemm_half_q_half_kernel const uint32_t* b_ptr = b_q_weight + qk * size_n + n; const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; + int a_stride = EXL2_BLOCK_KN_SIZE; // Initial group int scales_idx = 0; - float qs_f0 = scales[scales_idx][0]; - float qs_f1 = scales[scales_idx][1]; - float qs_f2 = scales[scales_idx][2]; - float qs_f3 = scales[scales_idx][3]; - int nextgroup = offset_k + groupsize; + half qs_h0 = scales[scales_idx][0]; + half qs_h1 = scales[scales_idx][1]; + half qs_h2 = scales[scales_idx][2]; + half qs_h3 = scales[scales_idx][3]; + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; // Column result - float block_c[m_count][4] = {}; + half block_c[m_count][4] = {}; // Dequantize groups @@ -218,11 +283,11 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll @@ -240,10 +305,11 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr 
(use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 8; } @@ -256,11 +322,11 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll @@ -279,10 +345,11 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 16; } @@ -295,11 +362,11 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll @@ -320,10 +387,11 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 32; } @@ -337,11 +405,11 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll @@ -358,10 +426,11 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - 
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 8; } @@ -374,11 +443,11 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll @@ -397,10 +466,11 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 32; } @@ -413,15 +483,15 @@ __global__ void gemm_half_q_half_kernel { group++; scales_idx++; - qs_f0 = scales[scales_idx][0]; - qs_f1 = scales[scales_idx][1]; - qs_f2 = scales[scales_idx][2]; - qs_f3 = scales[scales_idx][3]; - nextgroup += groupsize; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; } #pragma unroll - for (int j = 0; j < 2; j++) + for (int j = 0; j < 1; j++) { int4 load_int4[1]; load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; @@ -434,15 +504,16 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { - block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0); - block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1); - block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2); - block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); } a_ptr += 16; } - k += 32; + k += 16; } // Accumulate column sums in c @@ -450,38 +521,60 @@ __global__ void gemm_half_q_half_kernel for (int m = 0; m < m_count; m++) { half2* out = (half2*)c_.item_ptr(offset_m + m, n); - half2 result01 = 
__halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); - half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + + if constexpr (mul_r_weights) + { + half2 w_mul2 = __half2half2(weights[m].as_half); + result01 = __hmul2(result01, w_mul2); + result23 = __hmul2(result23, w_mul2); + } + atomicAdd(out , result01); atomicAdd(out + 1, result23); +// *out = result01; +// *(out + 1) = result23; } } -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count) +template +struct map_m_count_exl2 { + static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count) + { + #if EXL2_BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>; + #endif + return NULL; + } +}; + +fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights) { - #if BLOCK_M_SIZE_MAX >= 1 - if (m_count == 1) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 2 - if (m_count == 2) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 3 - if (m_count == 3) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 4 - if (m_count == 4) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 5 - if (m_count == 5) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 6 - if (m_count == 6) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 7 - if (m_count == 7) return gemm_half_q_half_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 8 - if (m_count == 8) return gemm_half_q_half_kernel; - #endif + if (!r_weights && !mul_r_weights) return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); + if (!r_weights && mul_r_weights) return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); + if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count); + if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count); return NULL; } diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh index ebaa42d0..74b0db2b 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh @@ -18,6 +18,15 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const 
half* a_ptr) return __half2float(__low2half(result)) + __half2float(__high2half(result)); } +__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return result; +} + typedef void (*fp_gemm_half_q_half_gptq_kernel) ( const half*, @@ -32,10 +41,12 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel) const int, const uint16_t*, const int, - const bool + const bool, + const half*, + const int ); -template +template __global__ void gemm_half_q_half_gptq_kernel ( const half* __restrict__ a, @@ -50,7 +61,9 @@ __global__ void gemm_half_q_half_gptq_kernel const int groupsize, const uint16_t* __restrict__ b_q_perm, const int rows_4, - const bool clear + const bool clear, + const half* r_weights, + const int r_weights_stride ) { MatrixView_half a_(a, size_m, size_k); @@ -62,19 +75,35 @@ __global__ void gemm_half_q_half_gptq_kernel // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4; int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; + int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE; - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n); int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k); int n = offset_n + t * 4; + // Read weights + + half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; + if constexpr (use_r_weights) + { + uint16_t any_w = 0; + const half* w_ptr = r_weights; + for (int m = 0; m < m_count; ++m) + { + weights[m].as_half = *w_ptr; + w_ptr += r_weights_stride; + any_w |= weights[m].as_uint16; + } + if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!) 
+ } + // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + __shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE]; if (offset_k + t < end_k) { @@ -113,16 +142,16 @@ __global__ void gemm_half_q_half_gptq_kernel const uint32_t* b_ptr = b_q_weight + qk * size_n + n; const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; + int a_stride = GPTQ_BLOCK_KN_SIZE; // Initial group int zeros[4]; - float scales[4]; + half2 scales[4]; half2 z1z16[4][2]; half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); + b_gptq_scales_.item4_h2(scales, group, n); dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); @@ -132,7 +161,7 @@ __global__ void gemm_half_q_half_gptq_kernel // Column result - float block_c[m_count][4] = {}; + half2 block_c[m_count][4] = {}; // Dequantize and multiply @@ -144,7 +173,7 @@ __global__ void gemm_half_q_half_gptq_kernel group++; nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); + b_gptq_scales_.item4_h2(scales, group, n); dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); @@ -166,10 +195,11 @@ __global__ void gemm_half_q_half_gptq_kernel #pragma unroll for (int m = 0; m < m_count; m++) { - block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); - block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); - block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); - block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); } b_ptr += size_n; @@ -182,38 +212,62 @@ __global__ void gemm_half_q_half_gptq_kernel for (int m = 0; m < m_count; m++) { half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); - half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0])); + half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1])); + half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2])); + half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3])); + half2 result01 = __halves2half2(result0, result1); + half2 result23 = __halves2half2(result2, result3); + + if constexpr (mul_r_weights) + { + half2 w_mul2 = __half2half2(weights[m].as_half); + result01 = __hmul2(result01, w_mul2); + result23 = __hmul2(result23, w_mul2); + } + atomicAdd(out , result01); atomicAdd(out + 1, result23); } } -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) +template +struct map_m_count_gptq { + static constexpr fp_gemm_half_q_half_gptq_kernel 
pick_gemm_half_q_half_gptq_kernel(int m_count) + { + #if GPTQ_BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>; + #endif + return NULL; + } +}; + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights) { - #if BLOCK_M_SIZE_MAX >= 1 - if (m_count == 1) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 2 - if (m_count == 2) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 3 - if (m_count == 3) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 4 - if (m_count == 4) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 5 - if (m_count == 5) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 6 - if (m_count == 6) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 7 - if (m_count == 7) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 8 - if (m_count == 8) return gemm_half_q_half_gptq_kernel; - #endif + if (!r_weights && !mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); + if (!r_weights && mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); + if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count); + if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count); return NULL; } diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu index 6aed7470..ae08cc1f 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu @@ -57,6 +57,7 @@ QMatrix::QMatrix uint32_t* _q_scale, half* _q_scale_max, uint16_t* _q_groups, + uint16_t* _q_group_map, uint32_t* _gptq_qzeros, half* _gptq_scales, @@ -80,13 +81,17 @@ QMatrix::QMatrix cuda_q_scale = _q_scale; cuda_q_scale_max = _q_scale_max; cuda_q_groups = _q_groups; + cuda_q_group_map = _q_group_map; cuda_gptq_qzeros = _gptq_qzeros; cuda_gptq_scales = _gptq_scales; is_gptq = (_gptq_qzeros != NULL); - groupsize = 1; - while (groupsize * groups < height) groupsize *= 2; + if (is_gptq) + { + gptq_groupsize = 1; + while (gptq_groupsize * groups < height) gptq_groupsize *= 2; + } // Create group map @@ -102,15 +107,26 @@ QMatrix::QMatrix uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t)); cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 
* sizeof(uint16_t), cudaMemcpyDeviceToHost); + int row = 0; for (int i = 0; i < groups; i++) { int bits = cpu_q_groups[i * 2]; - if (bits == 8) rows_8 += groupsize; - if (bits == 6) rows_6 += groupsize; - if (bits == 5) rows_5 += groupsize; - if (bits == 4) rows_4 += groupsize; - if (bits == 3) rows_3 += groupsize; - if (bits == 2) rows_2 += groupsize; + + int rows; + if (i < groups - 1) + { + int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1]; + rows = qrows * 32 / bits; + } + else rows = height - row; + + if (bits == 8) rows_8 += rows; + if (bits == 6) rows_6 += rows; + if (bits == 5) rows_5 += rows; + if (bits == 4) rows_4 += rows; + if (bits == 3) rows_3 += rows; + if (bits == 2) rows_2 += rows; + row += rows; } free(cpu_q_groups); @@ -138,6 +154,13 @@ QMatrix::QMatrix } } +// DBGI(rows_8); +// DBGI(rows_6); +// DBGI(rows_5); +// DBGI(rows_4); +// DBGI(rows_3); +// DBGI(rows_2); + // Shuffle quantized data dim3 blockDim, gridDim; @@ -283,10 +306,10 @@ __global__ void reconstruct_kernel const uint16_t* __restrict__ b_q_perm, const uint32_t* __restrict__ b_q_scale, const half* __restrict__ b_q_scale_max, - //const uint16_t* __restrict__ b_q_groups, + const uint16_t* __restrict__ b_q_group_map, const int size_k, const int size_n, - const int groupsize, + //const int groupsize, const int groups, half* __restrict__ b, const int rows_8, @@ -317,7 +340,8 @@ __global__ void reconstruct_kernel // Find initial group - int group = offset_k / groupsize; + // int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; int pre_rows_8 = min(rows_8, offset_k); int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; @@ -337,7 +361,7 @@ __global__ void reconstruct_kernel half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); half2 qs_h2 = __halves2half2(qs_h, qs_h); - int nextgroup = offset_k + groupsize; + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); int k = offset_k; @@ -347,7 +371,7 @@ __global__ void reconstruct_kernel while (k < rows_8 && k < end_k) { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } for (int p = 0; p < 4; p++) { half2 dq[4]; @@ -363,7 +387,7 @@ __global__ void reconstruct_kernel while (k < rows_6 && k < end_k) { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } for (int p = 0; p < 2; p++) { half2 dq[8]; @@ -380,7 +404,7 @@ __global__ void reconstruct_kernel while (k < rows_5 && k < end_k) { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } for (int p = 0; p < 1; p++) { half2 dq[16]; @@ -399,7 +423,7 @@ __global__ void reconstruct_kernel while (k < rows_4 && k < end_k) { - if (k == nextgroup) { group++; qs_h = 
dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } for (int p = 0; p < 4; p++) { half2 dq[4]; @@ -414,7 +438,7 @@ __global__ void reconstruct_kernel while (k < rows_3 && k < end_k) { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } for (int p = 0; p < 1; p++) { half2 dq[16]; @@ -431,8 +455,8 @@ __global__ void reconstruct_kernel while (k < rows_2 && k < end_k) { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } - for (int p = 0; p < 2; p++) + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 1; p++) { half2 dq[8]; uint32_t q_0 = *b_ptr; b_ptr += size_n; @@ -441,7 +465,7 @@ __global__ void reconstruct_kernel half* dqh = (half*) dq; for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); } - k += 32; + k += 16; } } @@ -461,10 +485,10 @@ void QMatrix::reconstruct(half* out) cuda_q_perm, cuda_q_scale, cuda_q_scale_max, - //cuda_q_groups, + cuda_q_group_map, height, width, - groupsize, + //groupsize, groups, out, rows_8, @@ -487,7 +511,7 @@ void QMatrix::reconstruct(half* out) //const uint16_t* __restrict__ b_q_groups, height, width, - groupsize, + gptq_groupsize, groups, out, rows_4 diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh index dda83a4f..d36b8d66 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh @@ -18,7 +18,7 @@ public: int height; int width; int groups; - int groupsize; + int gptq_groupsize; int rows_8; int rows_6; @@ -33,6 +33,7 @@ public: uint32_t* cuda_q_scale = NULL; half* cuda_q_scale_max = NULL; uint16_t* cuda_q_groups = NULL; + uint16_t* cuda_q_group_map = NULL; uint32_t* cuda_gptq_qzeros = NULL; half* cuda_gptq_scales = NULL; @@ -53,6 +54,7 @@ public: uint32_t* _q_scale, half* _q_scale_max, uint16_t* _q_groups, + uint16_t* _q_group_map, uint32_t* _gptq_qzeros, half* _gptq_scales, diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh index 71657191..cac9df9c 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh @@ -7,6 +7,7 @@ union half2_uint32 half2 as_half2; __device__ half2_uint32(uint32_t val) : as_uint32(val) {} __device__ half2_uint32(half2 val) : as_half2(val) {} + __device__ half2_uint32() : as_uint32(0) {} }; union half_uint16 @@ -15,6 +16,7 @@ union half_uint16 half as_half; __device__ half_uint16(uint16_t val) : as_uint16(val) {} __device__ half_uint16(half val) : as_half(val) {} + __device__ half_uint16() : as_uint16(0) {} }; // Max_scale premultiplied by 1/256 diff --git 
a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh index 06a58d18..f56eda79 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh @@ -1,3 +1,11 @@ +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include +#include #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) @@ -40,3 +48,7 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort= if (abort) exit(code); } } + +void print_global_mem(const half* ptr, int rows, int columns, int stride); + +#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/ext.cpp b/server/exllamav2_kernels/exllamav2_kernels/ext.cpp index 5e52e6ab..ff4e1851 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/ext.cpp +++ b/server/exllamav2_kernels/exllamav2_kernels/ext.cpp @@ -31,6 +31,7 @@ uintptr_t make_q_matrix torch::Tensor q_scale, torch::Tensor q_scale_max, torch::Tensor q_groups, + torch::Tensor q_group_map, torch::Tensor gptq_qzeros, torch::Tensor gptq_scales, torch::Tensor gptq_g_idx, @@ -43,6 +44,7 @@ uintptr_t make_q_matrix TORCH_CHECK_DTYPE_OPT(q_scale, kInt); TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf); TORCH_CHECK_DTYPE_OPT(q_groups, kShort); + TORCH_CHECK_DTYPE_OPT(q_group_map, kShort); TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); @@ -83,12 +85,15 @@ uintptr_t make_q_matrix q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(), q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(), q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(), + q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(), gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), gptq_g_idx.device().is_meta() ? 
NULL : (uint32_t*) gptq_g_idx.data_ptr(), (half*) temp_dq.data_ptr() ); + if (m->failed) throw std::runtime_error("CUDA out of memory"); + return reinterpret_cast (m); } diff --git a/server/tests/utils/test_hub.py b/server/tests/utils/test_hub.py index 5438c153..721820f5 100644 --- a/server/tests/utils/test_hub.py +++ b/server/tests/utils/test_hub.py @@ -32,10 +32,10 @@ def fresh_cache(): current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d - os.environ['HUGGINGFACE_HUB_CACHE'] = d + os.environ["HUGGINGFACE_HUB_CACHE"] = d yield huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value - os.environ['HUGGINGFACE_HUB_CACHE'] = current_value + os.environ["HUGGINGFACE_HUB_CACHE"] = current_value text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value @@ -47,7 +47,7 @@ def prefetched(): revision="main", local_files_only=False, repo_type="model", - allow_patterns=["*.safetensors"] + allow_patterns=["*.safetensors"], ) yield model_id @@ -61,7 +61,15 @@ def test_weight_hub_files_offline_error(offline, fresh_cache): def test_weight_hub_files_offline_ok(prefetched, offline): # If the model is prefetched then we should be able to get the weight files from local cache filenames = weight_hub_files(prefetched) - assert filenames == ['model.safetensors'] + root = None + assert len(filenames) == 1 + for f in filenames: + curroot, filename = os.path.split(f) + if root is None: + root = curroot + else: + assert root == curroot + assert filename == "model.safetensors" def test_weight_hub_files(): diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index cd93d32a..22d03adf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -71,7 +71,7 @@ def _load_multi_mqa_gptq( g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = g_idx.to(device=weights.device) - bits, groupsize = weights._get_gptq_params() + bits, groupsize, _ = weights._get_gptq_params() from text_generation_server.utils.layers import HAS_EXLLAMA diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py index dd41b269..a24e834b 100644 --- a/server/text_generation_server/utils/gptq/exllamav2.py +++ b/server/text_generation_server/utils/gptq/exllamav2.py @@ -27,6 +27,32 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): return output.view(output_shape) +# Group map needed for irregular group sizes + + +def make_group_map(q_groups, num_qrows): + + gr = q_groups.tolist() + group_map = [] + num_groups = len(gr) // 2 + + for i in range(num_groups): + bits = gr[i * 2] + if i < num_groups - 1: + qrows = gr[i * 2 + 3] - gr[i * 2 + 1] + else: + qrows = num_qrows - gr[i * 2 + 1] + rows = qrows * 32 // bits + for j in range(rows): + group_map += [i] + group_map += [rows - j] + + return torch.tensor(group_map, dtype=torch.short, device=q_groups.device) + + +# Create Q matrix + + def ext_make_q_matrix(w: dict, temp_dq, key: str = None): """ Create Q matrix @@ -37,6 +63,10 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): w["q_scale_max"] /= 256 w["q_perm"] = w["q_perm"].short() w["q_invperm"] = w["q_invperm"].short() + + if "q_group_map" not in w: + w["q_group_map"] = 
make_group_map(w["q_groups"], w["q_weight"].shape[0]) + return make_q_matrix( w["q_weight"], w["q_perm"], @@ -44,6 +74,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): w["q_scale"], w["q_scale_max"], w["q_groups"], + w["q_group_map"], none_tensor, none_tensor, none_tensor, @@ -70,6 +101,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): none_tensor, none_tensor, none_tensor, + none_tensor, w["qzeros"], w["scales"], w["g_idx"].cpu(), @@ -84,6 +116,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): none_tensor, none_tensor, none_tensor, + none_tensor, w["qzeros"], w["scales"], none_tensor, diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index 019d4855..b56484f6 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -18,7 +18,9 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] -def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]: +def _cached_weight_files( + model_id: str, revision: Optional[str], extension: str +) -> List[str]: """Guess weight files from the cached revision snapshot directory""" d = _get_cached_revision_directory(model_id, revision) if not d: @@ -27,7 +29,9 @@ def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) return filenames -def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]: +def _weight_hub_files_from_model_info( + info: hf_api.ModelInfo, extension: str +) -> List[str]: return [ s.rfilename for s in info.siblings @@ -44,20 +48,27 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: # see _weight_hub_files_from_model_info, that's also what is # done there with the len(s.rfilename.split("/")) == 1 condition root, _, files = next(os.walk(str(d))) - filenames = [f for f in files - if f.endswith(extension) - and "arguments" not in f - and "args" not in f - and "training" not in f] + filenames = [ + os.path.join(root, f) + for f in files + if f.endswith(extension) + and "arguments" not in f + and "args" not in f + and "adapter" not in f + and "training" not in f + ] return filenames -def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]: +def _get_cached_revision_directory( + model_id: str, revision: Optional[str] +) -> Optional[Path]: if revision is None: revision = "main" repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path( - file_download.repo_folder_name(repo_id=model_id, repo_type="model")) + file_download.repo_folder_name(repo_id=model_id, repo_type="model") + ) if not repo_cache.is_dir(): # No cache for this model @@ -85,7 +96,7 @@ def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Op def weight_hub_files( - model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" + model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" ) -> List[str]: """Get the weights filenames on the hub""" api = HfApi() diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 011a9382..6648b55a 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -19,6 +19,7 @@ from accelerate import init_empty_weights from text_generation_server.utils.gptq.quant_linear import QuantLinear from 
text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +from text_generation_server.utils.log import log_once HAS_AWQ = True try: @@ -35,10 +36,11 @@ HAS_EXLLAMA = False CAN_EXLLAMA = major >= 8 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1: - logger.warning( + V2 = False + log_once( + logger.warning, "Disabling exllama v2 and using v1 instead because there are issues when sharding" ) - V2 = False if os.getenv("DISABLE_EXLLAMA") == "True": HAS_EXLLAMA = False diff --git a/server/text_generation_server/utils/log.py b/server/text_generation_server/utils/log.py new file mode 100644 index 00000000..d831fa76 --- /dev/null +++ b/server/text_generation_server/utils/log.py @@ -0,0 +1,6 @@ +from functools import lru_cache + + +@lru_cache(10) +def log_once(log, msg:str): + log(msg) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index a2cca2ea..ee1899ab 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -6,6 +6,7 @@ import torch from loguru import logger from huggingface_hub import hf_hub_download import json +from text_generation_server.utils.log import log_once class Weights: @@ -161,7 +162,7 @@ class Weights: else: g_idx = None - bits, groupsize = self._get_gptq_params() + bits, groupsize, _ = self._get_gptq_params() weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: slice_ = self._get_slice(f"{prefix}.weight") @@ -211,10 +212,10 @@ class Weights: else: g_idx = None - bits, groupsize = self._get_gptq_params() + bits, groupsize, desc_act = self._get_gptq_params() from text_generation_server.utils.layers import HAS_EXLLAMA - use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" + use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] @@ -240,11 +241,15 @@ class Weights: def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq": use_exllama = True - bits, groupsize = self._get_gptq_params() + bits, groupsize, desc_act = self._get_gptq_params() if bits != 4: use_exllama = False + if desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + if self.process_group.size() > 1: g_idx = self.get_tensor(f"{prefix}.g_idx") if g_idx is not None: @@ -274,12 +279,18 @@ class Weights: if use_exllama: if not HAS_EXLLAMA: if CAN_EXLLAMA: - logger.warning( + log_once( + logger.warning, "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True" ) use_exllama = False else: - logger.info(f"Using exllama kernels v{HAS_EXLLAMA}") + log_once( + logger.info, + f"Using exllama kernels v{HAS_EXLLAMA}" + ) + + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) if use_exllama and groupsize != -1: qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) @@ -288,14 +299,12 @@ class Weights: qzeros = self.get_tensor(f"{prefix}.qzeros") scales = self.get_tensor(f"{prefix}.scales") - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if use_exllama: g_idx = g_idx - g_idx[0] weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) elif quantize == "awq": - bits, groupsize = self._get_gptq_params() + bits, groupsize, _ = self._get_gptq_params() try: qweight = 
self.get_sharded(f"{prefix}.qweight", dim=0) @@ -314,18 +323,20 @@ class Weights: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight - def _get_gptq_params(self) -> Tuple[int, int]: + def _get_gptq_params(self) -> Tuple[int, int, int]: try: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() + desc_act = False except (SafetensorError, RuntimeError) as e: try: bits = self.gptq_bits groupsize = self.gptq_groupsize + desc_act = getattr(self, "gptq_desc_act", False) except Exception: raise e - return bits, groupsize + return bits, groupsize, desc_act def _set_gptq_params(self, model_id, revision): filename = "config.json" @@ -340,6 +351,7 @@ class Weights: data = json.load(f) self.gptq_bits = data["quantization_config"]["bits"] self.gptq_groupsize = data["quantization_config"]["group_size"] + self.gptq_desc_act = data["quantization_config"]["desc_act"] except Exception: filename = "quantize_config.json" try: @@ -353,6 +365,7 @@ class Weights: data = json.load(f) self.gptq_bits = data["bits"] self.gptq_groupsize = data["group_size"] + self.gptq_desc_act = data["desc_act"] except Exception: filename = "quant_config.json" try: @@ -366,5 +379,6 @@ class Weights: data = json.load(f) self.gptq_bits = data["w_bit"] self.gptq_groupsize = data["q_group_size"] + self.gptq_desc_act = data["desc_act"] except Exception: pass
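For reference, the group map built by make_group_map above (and consumed by the CUDA kernels as b_q_group_map[k * 2] / b_q_group_map[k * 2 + 1]) stores two 16-bit values per dequantized row: the index of the group the row belongs to, and how many rows remain before the next group starts. A small sketch with made-up q_groups values (not from a real checkpoint), assuming the package from this diff is importable:

    import torch

    # Added by this diff in server/text_generation_server/utils/gptq/exllamav2.py
    from text_generation_server.utils.gptq.exllamav2 import make_group_map

    # q_groups holds (bits, first_packed_row) pairs, one pair per group -- toy values:
    #   group 0: 4-bit, packed rows start at qrow 0
    #   group 1: 8-bit, packed rows start at qrow 16
    q_groups = torch.tensor([4, 0, 8, 16], dtype=torch.short)
    num_qrows = 48  # total packed rows, so group 1 covers qrows 16..47

    group_map = make_group_map(q_groups, num_qrows)

    # Each group unpacks to qrows * 32 // bits real rows: 16 * 32 // 4 = 128 and 32 * 32 // 8 = 128.
    # For a real row k the kernels read:
    #   group_map[2 * k]      -> group index of row k
    #   group_map[2 * k + 1]  -> rows remaining until the next group boundary
    assert group_map[0::2].tolist() == [0] * 128 + [1] * 128
    assert group_map[1].item() == 128            # all 128 rows of group 0 ahead at k = 0
    assert group_map[2 * 128 + 1].item() == 128  # likewise for the first row of group 1

This reflects the irregular group sizes that EXL2 checkpoints can use, which is why the kernels now read the per-row map instead of a fixed groupsize.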