Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Commit b223ac70b6: Merge branch 'huggingface:main' into main
@@ -39,7 +39,7 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>

## Supported Hardware

TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.

TGI also supports ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are currently not supported in the ROCm version of TGI; support may be extended in the future:

* Quantization (GPTQ, AWQ, etc.)
@@ -47,5 +47,5 @@ TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with pag
* Kernel for sliding window attention (Mistral)

TGI is also supported on the following AI hardware accelerators:
- *Habana first-gen Gaudi and Gaudi2:* check out this [example](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) on how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
- *Habana first-gen Gaudi and Gaudi2:* check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
* *AWS Inferentia2:* check out this [guide](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) on how to serve models with TGI on Inferentia2.
@@ -2,6 +2,7 @@
#define _config_h

#define MAX_Q_GEMM_ROWS 50
#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS

#define QMODE_2BIT 1
#define QMODE_3BIT 1
@@ -10,4 +11,5 @@
#define QMODE_6BIT 0
#define QMODE_8BIT 0

#endif
@ -10,16 +10,19 @@
|
||||
#include "quant/qdq_6.cuh"
|
||||
#include "quant/qdq_8.cuh"
|
||||
|
||||
#define BLOCK_KN_SIZE 128
|
||||
#define BLOCK_M_SIZE_MAX 8
|
||||
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
|
||||
#define GPTQ_BLOCK_KN_SIZE 128
|
||||
#define GPTQ_BLOCK_M_SIZE_MAX 8
|
||||
#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
|
||||
|
||||
#define EXL2_BLOCK_KN_SIZE 64
|
||||
#define EXL2_BLOCK_M_SIZE_MAX 8
|
||||
#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
|
||||
|
||||
#define CLEAR_N_SIZE 256
|
||||
|
||||
#include "q_gemm_kernel.cuh"
|
||||
#include "q_gemm_kernel_gptq.cuh"
|
||||
|
||||
#include "compat_gemm.cuh"
|
||||
|
||||
void gemm_half_q_half_cuda_part
|
||||
(
|
||||
const half* a,
|
||||
@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part
|
||||
int size_n,
|
||||
int size_k,
|
||||
int m_count,
|
||||
bool clear
|
||||
bool clear,
|
||||
const half* r_weights,
|
||||
int r_weights_stride,
|
||||
bool mul_r_weights
|
||||
)
|
||||
{
|
||||
if (!b->is_gptq)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.x = EXL2_BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
||||
gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
|
||||
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
|
||||
|
||||
kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
@ -55,7 +61,7 @@ void gemm_half_q_half_cuda_part
|
||||
size_n,
|
||||
size_k,
|
||||
b->groups,
|
||||
b->groupsize,
|
||||
b->cuda_q_group_map,
|
||||
b->cuda_q_perm,
|
||||
b->rows_8,
|
||||
b->rows_6,
|
||||
@ -63,24 +69,27 @@ void gemm_half_q_half_cuda_part
|
||||
b->rows_4,
|
||||
b->rows_3,
|
||||
b->rows_2,
|
||||
clear
|
||||
clear,
|
||||
r_weights,
|
||||
r_weights_stride
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.x = GPTQ_BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
||||
gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
|
||||
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);
|
||||
|
||||
// DBGX((uint64_t) b->cuda_q_perm);
|
||||
// DBGI(b->rows_4);
|
||||
// DBGI(b->height);
|
||||
// DBGX((uint64_t) r_weights);
|
||||
// if (r_weights)
|
||||
// print_global_mem(r_weights, 1, 1, 1);
|
||||
// DBGI(r_weights_stride);
|
||||
|
||||
kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part
|
||||
size_n,
|
||||
size_k,
|
||||
b->groups,
|
||||
b->groupsize,
|
||||
b->gptq_groupsize,
|
||||
b->cuda_q_perm,
|
||||
b->rows_4,
|
||||
clear
|
||||
clear,
|
||||
r_weights,
|
||||
r_weights_stride
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -112,13 +123,14 @@ void gemm_half_q_half_cuda
|
||||
int size_k,
|
||||
bool clear,
|
||||
half* temp_dq,
|
||||
bool force_cuda
|
||||
bool force_cuda,
|
||||
const half* r_weights,
|
||||
const int r_weights_stride,
|
||||
bool mul_r_weights
|
||||
)
|
||||
{
|
||||
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
|
||||
{
|
||||
//printf("cublas\n");
|
||||
|
||||
// Reconstruct FP16 matrix, then cuBLAS
|
||||
|
||||
if (!temp_dq) temp_dq = b->temp_dq;
|
||||
@ -139,12 +151,12 @@ void gemm_half_q_half_cuda
|
||||
//const float alpha = 1.0f;
|
||||
//const float beta = clear ? 0.0f : 1.0f;
|
||||
//cublasSgemmEx(cublas_handle,
|
||||
// CUBLAS_OP_N,
|
||||
// CUBLAS_OP_N,
|
||||
// size_n, size_m, size_k,
|
||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||
// a, CUDA_R_16F, size_k,
|
||||
// &beta, c, CUDA_R_16F, size_n);
|
||||
// CUBLAS_OP_N,
|
||||
// CUBLAS_OP_N,
|
||||
// size_n, size_m, size_k,
|
||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||
// a, CUDA_R_16F, size_k,
|
||||
// &beta, c, CUDA_R_16F, size_n);
|
||||
|
||||
//const float alpha = 1.0f;
|
||||
//const float beta = clear ? 0.0f : 1.0f;
|
||||
@ -158,24 +170,21 @@ void gemm_half_q_half_cuda
|
||||
}
|
||||
else
|
||||
{
|
||||
//printf("cuda\n");
|
||||
|
||||
// Quantized matmul
|
||||
|
||||
//if (clear) clear_tensor_cuda(c, size_m, size_n);
|
||||
|
||||
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
|
||||
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
|
||||
int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
|
||||
int max_chunks = size_m / block_m_size_max;
|
||||
int last_chunk = max_chunks * block_m_size_max;
|
||||
int last_chunk_size = size_m - last_chunk;
|
||||
|
||||
if (max_chunks)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
|
||||
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
|
||||
}
|
||||
|
||||
if (last_chunk_size)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
|
||||
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -201,11 +210,10 @@ void clear_tensor_cuda
|
||||
int size_n
|
||||
)
|
||||
{
|
||||
return;
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = CLEAR_N_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
|
||||
gridDim.y = size_m;
|
||||
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
|
||||
// dim3 blockDim, gridDim;
|
||||
// blockDim.x = CLEAR_N_SIZE;
|
||||
// blockDim.y = 1;
|
||||
// gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
|
||||
// gridDim.y = size_m;
|
||||
// clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
|
||||
}
|
||||
|
@@ -20,7 +20,10 @@ void gemm_half_q_half_cuda
    int size_k,
    bool clear = false,
    half* reconstruct = NULL,
    bool force_cuda = false
    bool force_cuda = false,
    const half* r_weights = NULL,
    const int r_weights_stride = 0,
    bool mul_r_weights = false
);

void clear_tensor_cuda
@ -1,8 +1,5 @@
|
||||
#include "compat.cuh"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
@ -60,6 +57,47 @@ __forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, c
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
|
||||
{
|
||||
// Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
|
||||
|
||||
float result = {};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
half2 w01 = dq[i];
|
||||
float w0 = __low2float(w01);
|
||||
float w1 = __high2float(w01);
|
||||
float x0 = __half2float(*a_ptr++);
|
||||
float x1 = __half2float(*a_ptr++);
|
||||
result = fma(w0, x0, result);
|
||||
result = fma(w1, x1, result);
|
||||
}
|
||||
float qs = __half2float(qs_h);
|
||||
result *= qs;
|
||||
half result_h = __float2half_rn(result);
|
||||
return __hadd(result_h, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
half result_h = __hadd(__low2half(result), __high2half(result));
|
||||
return __hfma(result_h, qs_h, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
half result_h = __hadd(__low2half(result), __high2half(result));
|
||||
return __hfma(result_h, qs_h, g_result);
|
||||
}
|
||||
|
||||
|
||||
typedef void (*fp_gemm_half_q_half_kernel)
|
||||
@ -73,7 +111,7 @@ typedef void (*fp_gemm_half_q_half_kernel)
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint16_t*,
|
||||
const uint16_t*,
|
||||
const int,
|
||||
const int,
|
||||
@ -81,10 +119,12 @@ typedef void (*fp_gemm_half_q_half_kernel)
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const bool
|
||||
const bool,
|
||||
const half*,
|
||||
const int
|
||||
);
|
||||
|
||||
template <bool first_block, int m_count>
|
||||
template <int m_count, bool use_r_weights, bool mul_r_weights>
|
||||
__global__ void gemm_half_q_half_kernel
|
||||
(
|
||||
const half* __restrict__ a,
|
||||
@ -96,7 +136,7 @@ __global__ void gemm_half_q_half_kernel
|
||||
const int size_n,
|
||||
const int size_k,
|
||||
const int groups,
|
||||
const int groupsize,
|
||||
const uint16_t* __restrict__ b_q_group_map,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
@ -104,7 +144,9 @@ __global__ void gemm_half_q_half_kernel
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2,
|
||||
const bool clear
|
||||
const bool clear,
|
||||
const half* r_weights,
|
||||
const int r_weights_stride
|
||||
)
|
||||
{
|
||||
MatrixView_half a_(a, size_m, size_k);
|
||||
@ -115,18 +157,34 @@ __global__ void gemm_half_q_half_kernel
|
||||
|
||||
// Block
|
||||
|
||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||
int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4;
|
||||
int offset_m = blockIdx.y * m_count;
|
||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||
int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE;
|
||||
|
||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_m = min(offset_m + m_count, size_m);
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k);
|
||||
int n = offset_n + t * 4;
|
||||
|
||||
// Read weights
|
||||
|
||||
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
|
||||
if constexpr (use_r_weights)
|
||||
{
|
||||
uint16_t any_w = 0;
|
||||
const half* w_ptr = r_weights;
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
weights[m].as_half = *w_ptr;
|
||||
w_ptr += r_weights_stride;
|
||||
any_w |= weights[m].as_uint16;
|
||||
}
|
||||
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
|
||||
}
|
||||
|
||||
// Preload block_a
|
||||
|
||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
||||
__shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE];
|
||||
|
||||
if (offset_k + t < end_k)
|
||||
{
|
||||
@ -135,6 +193,7 @@ __global__ void gemm_half_q_half_kernel
|
||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||
half* block_a_ptr = block_a[m];
|
||||
half a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||
// half a0 = a_ptr[offset_k + t];
|
||||
block_a_ptr[t] = a0;
|
||||
}
|
||||
}
|
||||
@ -153,14 +212,19 @@ __global__ void gemm_half_q_half_kernel
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
//int group = offset_k / groupsize;
|
||||
int group = b_q_group_map[offset_k * 2];
|
||||
|
||||
// if (offset_m == 0 && t == 0)
|
||||
// DBGI2(offset_k, group);
|
||||
|
||||
// Preload scales
|
||||
|
||||
float scales[MAX_GROUPS_IN_BLOCK][4];
|
||||
half scales[EXL2_MAX_GROUPS_IN_BLOCK][4];
|
||||
|
||||
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
||||
for (int g = 0; g < groups_in_block; g++)
|
||||
//int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
||||
int temp_k = offset_k;
|
||||
for (int g = 0; temp_k < end_k; g++)
|
||||
{
|
||||
int qscales[4];
|
||||
b_q_scale_.item4(qscales, group + g, n);
|
||||
@ -168,11 +232,12 @@ __global__ void gemm_half_q_half_kernel
|
||||
qscales[1]++;
|
||||
qscales[2]++;
|
||||
qscales[3]++;
|
||||
float maxscale = __half2float(b_q_scale_max[group + g]);
|
||||
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
|
||||
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
|
||||
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
|
||||
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
|
||||
half maxscale = b_q_scale_max[group + g];
|
||||
scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale);
|
||||
scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale);
|
||||
scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale);
|
||||
scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale);
|
||||
temp_k += b_q_group_map[temp_k * 2 + 1];
|
||||
}
|
||||
|
||||
// a, b offset
|
||||
@ -193,20 +258,20 @@ __global__ void gemm_half_q_half_kernel
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
const half* a_ptr = &block_a[0][0];
|
||||
int a_stride = BLOCK_KN_SIZE;
|
||||
int a_stride = EXL2_BLOCK_KN_SIZE;
|
||||
|
||||
// Initial group
|
||||
|
||||
int scales_idx = 0;
|
||||
float qs_f0 = scales[scales_idx][0];
|
||||
float qs_f1 = scales[scales_idx][1];
|
||||
float qs_f2 = scales[scales_idx][2];
|
||||
float qs_f3 = scales[scales_idx][3];
|
||||
int nextgroup = offset_k + groupsize;
|
||||
half qs_h0 = scales[scales_idx][0];
|
||||
half qs_h1 = scales[scales_idx][1];
|
||||
half qs_h2 = scales[scales_idx][2];
|
||||
half qs_h3 = scales[scales_idx][3];
|
||||
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
|
||||
|
||||
// Column result
|
||||
|
||||
float block_c[m_count][4] = {};
|
||||
half block_c[m_count][4] = {};
|
||||
|
||||
// Dequantize groups
|
||||
|
||||
@ -218,11 +283,11 @@ __global__ void gemm_half_q_half_kernel
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@ -240,10 +305,11 @@ __global__ void gemm_half_q_half_kernel
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
a_ptr += 8;
|
||||
}
|
||||
@ -256,11 +322,11 @@ __global__ void gemm_half_q_half_kernel
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@ -279,10 +345,11 @@ __global__ void gemm_half_q_half_kernel
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
a_ptr += 16;
|
||||
}
|
||||
@ -295,11 +362,11 @@ __global__ void gemm_half_q_half_kernel
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@ -320,10 +387,11 @@ __global__ void gemm_half_q_half_kernel
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
a_ptr += 32;
|
||||
}
|
||||
@ -337,11 +405,11 @@ __global__ void gemm_half_q_half_kernel
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@ -358,10 +426,11 @@ __global__ void gemm_half_q_half_kernel
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
a_ptr += 8;
|
||||
}
|
||||
@ -374,11 +443,11 @@ __global__ void gemm_half_q_half_kernel
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@ -397,10 +466,11 @@ __global__ void gemm_half_q_half_kernel
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
a_ptr += 32;
|
||||
}
|
||||
@ -413,15 +483,15 @@ __global__ void gemm_half_q_half_kernel
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++)
|
||||
for (int j = 0; j < 1; j++)
|
||||
{
|
||||
int4 load_int4[1];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
@ -434,15 +504,16 @@ __global__ void gemm_half_q_half_kernel
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
|
||||
a_ptr += 16;
|
||||
}
|
||||
k += 32;
|
||||
k += 16;
|
||||
}
|
||||
|
||||
// Accumulate column sums in c
|
||||
@ -450,38 +521,60 @@ __global__ void gemm_half_q_half_kernel
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
|
||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
||||
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
|
||||
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
|
||||
|
||||
if constexpr (mul_r_weights)
|
||||
{
|
||||
half2 w_mul2 = __half2half2(weights[m].as_half);
|
||||
result01 = __hmul2(result01, w_mul2);
|
||||
result23 = __hmul2(result23, w_mul2);
|
||||
}
|
||||
|
||||
atomicAdd(out , result01);
|
||||
atomicAdd(out + 1, result23);
|
||||
// *out = result01;
|
||||
// *(out + 1) = result23;
|
||||
}
|
||||
}
|
||||
|
||||
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
|
||||
template <bool use_r_weights, bool mul_r_weights>
|
||||
struct map_m_count_exl2 {
|
||||
static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
|
||||
{
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 1
|
||||
if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 2
|
||||
if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 3
|
||||
if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 4
|
||||
if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 5
|
||||
if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 6
|
||||
if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 7
|
||||
if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 8
|
||||
if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
return NULL;
|
||||
}
|
||||
};
|
||||
|
||||
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights)
|
||||
{
|
||||
#if BLOCK_M_SIZE_MAX >= 1
|
||||
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 2
|
||||
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 3
|
||||
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 4
|
||||
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 5
|
||||
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 6
|
||||
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 7
|
||||
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 8
|
||||
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
|
||||
#endif
|
||||
if (!r_weights && !mul_r_weights) return map_m_count_exl2<false, false>::pick_gemm_half_q_half_kernel(m_count);
|
||||
if (!r_weights && mul_r_weights) return map_m_count_exl2<false, true>::pick_gemm_half_q_half_kernel(m_count);
|
||||
if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count);
|
||||
if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count);
|
||||
return NULL;
|
||||
}
|
||||
|
@ -18,6 +18,15 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
|
||||
return __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
||||
(
|
||||
const half*,
|
||||
@ -32,10 +41,12 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
||||
const int,
|
||||
const uint16_t*,
|
||||
const int,
|
||||
const bool
|
||||
const bool,
|
||||
const half*,
|
||||
const int
|
||||
);
|
||||
|
||||
template <bool first_block, int m_count>
|
||||
template <int m_count, bool use_r_weights, bool mul_r_weights>
|
||||
__global__ void gemm_half_q_half_gptq_kernel
|
||||
(
|
||||
const half* __restrict__ a,
|
||||
@ -50,7 +61,9 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
const int groupsize,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const int rows_4,
|
||||
const bool clear
|
||||
const bool clear,
|
||||
const half* r_weights,
|
||||
const int r_weights_stride
|
||||
)
|
||||
{
|
||||
MatrixView_half a_(a, size_m, size_k);
|
||||
@ -62,19 +75,35 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
|
||||
// Block
|
||||
|
||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||
int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4;
|
||||
int offset_m = blockIdx.y * m_count;
|
||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||
int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE;
|
||||
|
||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_m = min(offset_m + m_count, size_m);
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k);
|
||||
|
||||
int n = offset_n + t * 4;
|
||||
|
||||
// Read weights
|
||||
|
||||
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
|
||||
if constexpr (use_r_weights)
|
||||
{
|
||||
uint16_t any_w = 0;
|
||||
const half* w_ptr = r_weights;
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
weights[m].as_half = *w_ptr;
|
||||
w_ptr += r_weights_stride;
|
||||
any_w |= weights[m].as_uint16;
|
||||
}
|
||||
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
|
||||
}
|
||||
|
||||
// Preload block_a
|
||||
|
||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
||||
__shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE];
|
||||
|
||||
if (offset_k + t < end_k)
|
||||
{
|
||||
@ -113,16 +142,16 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
const half* a_ptr = &block_a[0][0];
|
||||
int a_stride = BLOCK_KN_SIZE;
|
||||
int a_stride = GPTQ_BLOCK_KN_SIZE;
|
||||
|
||||
// Initial group
|
||||
|
||||
int zeros[4];
|
||||
float scales[4];
|
||||
half2 scales[4];
|
||||
half2 z1z16[4][2];
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_f(scales, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
@ -132,7 +161,7 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
|
||||
// Column result
|
||||
|
||||
float block_c[m_count][4] = {};
|
||||
half2 block_c[m_count][4] = {};
|
||||
|
||||
// Dequantize and multiply
|
||||
|
||||
@ -144,7 +173,7 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_f(scales, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
@ -166,10 +195,11 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
#pragma unroll
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
||||
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
||||
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
||||
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
||||
block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
||||
block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
||||
block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
||||
}
|
||||
|
||||
b_ptr += size_n;
|
||||
@ -182,38 +212,62 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
||||
half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0]));
|
||||
half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1]));
|
||||
half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2]));
|
||||
half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3]));
|
||||
half2 result01 = __halves2half2(result0, result1);
|
||||
half2 result23 = __halves2half2(result2, result3);
|
||||
|
||||
if constexpr (mul_r_weights)
|
||||
{
|
||||
half2 w_mul2 = __half2half2(weights[m].as_half);
|
||||
result01 = __hmul2(result01, w_mul2);
|
||||
result23 = __hmul2(result23, w_mul2);
|
||||
}
|
||||
|
||||
atomicAdd(out , result01);
|
||||
atomicAdd(out + 1, result23);
|
||||
}
|
||||
}
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
|
||||
template <bool use_r_weights, bool mul_r_weights>
|
||||
struct map_m_count_gptq {
|
||||
static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count)
|
||||
{
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 1
|
||||
if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 2
|
||||
if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 3
|
||||
if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 4
|
||||
if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 5
|
||||
if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 6
|
||||
if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 7
|
||||
if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 8
|
||||
if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
return NULL;
|
||||
}
|
||||
};
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights)
|
||||
{
|
||||
#if BLOCK_M_SIZE_MAX >= 1
|
||||
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 2
|
||||
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 3
|
||||
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 4
|
||||
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 5
|
||||
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 6
|
||||
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 7
|
||||
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 8
|
||||
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
|
||||
#endif
|
||||
if (!r_weights && !mul_r_weights) return map_m_count_gptq<false, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||
if (!r_weights && mul_r_weights) return map_m_count_gptq<false, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||
if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||
if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||
return NULL;
|
||||
}
|
||||
|
@ -57,6 +57,7 @@ QMatrix::QMatrix
|
||||
uint32_t* _q_scale,
|
||||
half* _q_scale_max,
|
||||
uint16_t* _q_groups,
|
||||
uint16_t* _q_group_map,
|
||||
|
||||
uint32_t* _gptq_qzeros,
|
||||
half* _gptq_scales,
|
||||
@ -80,13 +81,17 @@ QMatrix::QMatrix
|
||||
cuda_q_scale = _q_scale;
|
||||
cuda_q_scale_max = _q_scale_max;
|
||||
cuda_q_groups = _q_groups;
|
||||
cuda_q_group_map = _q_group_map;
|
||||
cuda_gptq_qzeros = _gptq_qzeros;
|
||||
cuda_gptq_scales = _gptq_scales;
|
||||
|
||||
is_gptq = (_gptq_qzeros != NULL);
|
||||
|
||||
groupsize = 1;
|
||||
while (groupsize * groups < height) groupsize *= 2;
|
||||
if (is_gptq)
|
||||
{
|
||||
gptq_groupsize = 1;
|
||||
while (gptq_groupsize * groups < height) gptq_groupsize *= 2;
|
||||
}
|
||||
|
||||
// Create group map
|
||||
|
||||
@ -102,15 +107,26 @@ QMatrix::QMatrix
|
||||
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
|
||||
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
||||
|
||||
int row = 0;
|
||||
for (int i = 0; i < groups; i++)
|
||||
{
|
||||
int bits = cpu_q_groups[i * 2];
|
||||
if (bits == 8) rows_8 += groupsize;
|
||||
if (bits == 6) rows_6 += groupsize;
|
||||
if (bits == 5) rows_5 += groupsize;
|
||||
if (bits == 4) rows_4 += groupsize;
|
||||
if (bits == 3) rows_3 += groupsize;
|
||||
if (bits == 2) rows_2 += groupsize;
|
||||
|
||||
int rows;
|
||||
if (i < groups - 1)
|
||||
{
|
||||
int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
|
||||
rows = qrows * 32 / bits;
|
||||
}
|
||||
else rows = height - row;
|
||||
|
||||
if (bits == 8) rows_8 += rows;
|
||||
if (bits == 6) rows_6 += rows;
|
||||
if (bits == 5) rows_5 += rows;
|
||||
if (bits == 4) rows_4 += rows;
|
||||
if (bits == 3) rows_3 += rows;
|
||||
if (bits == 2) rows_2 += rows;
|
||||
row += rows;
|
||||
}
|
||||
|
||||
free(cpu_q_groups);
|
||||
@ -138,6 +154,13 @@ QMatrix::QMatrix
|
||||
}
|
||||
}
|
||||
|
||||
// DBGI(rows_8);
|
||||
// DBGI(rows_6);
|
||||
// DBGI(rows_5);
|
||||
// DBGI(rows_4);
|
||||
// DBGI(rows_3);
|
||||
// DBGI(rows_2);
|
||||
|
||||
// Shuffle quantized data
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
@ -283,10 +306,10 @@ __global__ void reconstruct_kernel
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const uint32_t* __restrict__ b_q_scale,
|
||||
const half* __restrict__ b_q_scale_max,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
const uint16_t* __restrict__ b_q_group_map,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int groupsize,
|
||||
//const int groupsize,
|
||||
const int groups,
|
||||
half* __restrict__ b,
|
||||
const int rows_8,
|
||||
@ -317,7 +340,8 @@ __global__ void reconstruct_kernel
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
// int group = offset_k / groupsize;
|
||||
int group = b_q_group_map[offset_k * 2];
|
||||
|
||||
int pre_rows_8 = min(rows_8, offset_k);
|
||||
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||
@ -337,7 +361,7 @@ __global__ void reconstruct_kernel
|
||||
|
||||
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
|
||||
half2 qs_h2 = __halves2half2(qs_h, qs_h);
|
||||
int nextgroup = offset_k + groupsize;
|
||||
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
|
||||
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
int k = offset_k;
|
||||
@ -347,7 +371,7 @@ __global__ void reconstruct_kernel
|
||||
|
||||
while (k < rows_8 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4];
|
||||
@ -363,7 +387,7 @@ __global__ void reconstruct_kernel
|
||||
|
||||
while (k < rows_6 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 2; p++)
|
||||
{
|
||||
half2 dq[8];
|
||||
@ -380,7 +404,7 @@ __global__ void reconstruct_kernel
|
||||
|
||||
while (k < rows_5 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[16];
|
||||
@ -399,7 +423,7 @@ __global__ void reconstruct_kernel
|
||||
|
||||
while (k < rows_4 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4];
|
||||
@ -414,7 +438,7 @@ __global__ void reconstruct_kernel
|
||||
|
||||
while (k < rows_3 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[16];
|
||||
@ -431,8 +455,8 @@ __global__ void reconstruct_kernel
|
||||
|
||||
while (k < rows_2 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 2; p++)
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[8];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
@ -441,7 +465,7 @@ __global__ void reconstruct_kernel
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
k += 16;
|
||||
}
|
||||
}
|
||||
|
||||
@ -461,10 +485,10 @@ void QMatrix::reconstruct(half* out)
|
||||
cuda_q_perm,
|
||||
cuda_q_scale,
|
||||
cuda_q_scale_max,
|
||||
//cuda_q_groups,
|
||||
cuda_q_group_map,
|
||||
height,
|
||||
width,
|
||||
groupsize,
|
||||
//groupsize,
|
||||
groups,
|
||||
out,
|
||||
rows_8,
|
||||
@ -487,7 +511,7 @@ void QMatrix::reconstruct(half* out)
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
height,
|
||||
width,
|
||||
groupsize,
|
||||
gptq_groupsize,
|
||||
groups,
|
||||
out,
|
||||
rows_4
|
||||
|
@ -18,7 +18,7 @@ public:
|
||||
int height;
|
||||
int width;
|
||||
int groups;
|
||||
int groupsize;
|
||||
int gptq_groupsize;
|
||||
|
||||
int rows_8;
|
||||
int rows_6;
|
||||
@ -33,6 +33,7 @@ public:
|
||||
uint32_t* cuda_q_scale = NULL;
|
||||
half* cuda_q_scale_max = NULL;
|
||||
uint16_t* cuda_q_groups = NULL;
|
||||
uint16_t* cuda_q_group_map = NULL;
|
||||
uint32_t* cuda_gptq_qzeros = NULL;
|
||||
half* cuda_gptq_scales = NULL;
|
||||
|
||||
@ -53,6 +54,7 @@ public:
|
||||
uint32_t* _q_scale,
|
||||
half* _q_scale_max,
|
||||
uint16_t* _q_groups,
|
||||
uint16_t* _q_group_map,
|
||||
|
||||
uint32_t* _gptq_qzeros,
|
||||
half* _gptq_scales,
|
||||
|
@@ -7,6 +7,7 @@ union half2_uint32
    half2 as_half2;
    __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
    __device__ half2_uint32(half2 val) : as_half2(val) {}
    __device__ half2_uint32() : as_uint32(0) {}
};

union half_uint16
@@ -15,6 +16,7 @@ union half_uint16
    half as_half;
    __device__ half_uint16(uint16_t val) : as_uint16(val) {}
    __device__ half_uint16(half val) : as_half(val) {}
    __device__ half_uint16() : as_uint16(0) {}
};

// Max_scale premultiplied by 1/256
@@ -1,3 +1,11 @@
#ifndef _util_cuh
#define _util_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

@@ -40,3 +48,7 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=
        if (abort) exit(code);
    }
}

void print_global_mem(const half* ptr, int rows, int columns, int stride);

#endif
@@ -31,6 +31,7 @@ uintptr_t make_q_matrix
    torch::Tensor q_scale,
    torch::Tensor q_scale_max,
    torch::Tensor q_groups,
    torch::Tensor q_group_map,
    torch::Tensor gptq_qzeros,
    torch::Tensor gptq_scales,
    torch::Tensor gptq_g_idx,
@@ -43,6 +44,7 @@ uintptr_t make_q_matrix
    TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
    TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
    TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
    TORCH_CHECK_DTYPE_OPT(q_group_map, kShort);
    TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
    TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
    TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
@@ -83,12 +85,15 @@ uintptr_t make_q_matrix
        q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
        q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
        q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
        q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(),
        gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
        gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
        gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
        (half*) temp_dq.data_ptr()
    );

    if (m->failed) throw std::runtime_error("CUDA out of memory");

    return reinterpret_cast<uintptr_t> (m);
}
@@ -32,10 +32,10 @@ def fresh_cache():
        current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
        text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
        os.environ['HUGGINGFACE_HUB_CACHE'] = d
        os.environ["HUGGINGFACE_HUB_CACHE"] = d
        yield
        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
        os.environ['HUGGINGFACE_HUB_CACHE'] = current_value
        os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
        text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value


@@ -47,7 +47,7 @@ def prefetched():
        revision="main",
        local_files_only=False,
        repo_type="model",
        allow_patterns=["*.safetensors"]
        allow_patterns=["*.safetensors"],
    )
    yield model_id

@@ -61,7 +61,15 @@ def test_weight_hub_files_offline_error(offline, fresh_cache):
def test_weight_hub_files_offline_ok(prefetched, offline):
    # If the model is prefetched then we should be able to get the weight files from local cache
    filenames = weight_hub_files(prefetched)
    assert filenames == ['model.safetensors']
    root = None
    assert len(filenames) == 1
    for f in filenames:
        curroot, filename = os.path.split(f)
        if root is None:
            root = curroot
        else:
            assert root == curroot
        assert filename == "model.safetensors"


def test_weight_hub_files():
@@ -71,7 +71,7 @@ def _load_multi_mqa_gptq(

        g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
        g_idx = g_idx.to(device=weights.device)
        bits, groupsize = weights._get_gptq_params()
        bits, groupsize, _ = weights._get_gptq_params()

        from text_generation_server.utils.layers import HAS_EXLLAMA

@@ -27,6 +27,32 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
    return output.view(output_shape)


# Group map needed for irregular group sizes


def make_group_map(q_groups, num_qrows):

    gr = q_groups.tolist()
    group_map = []
    num_groups = len(gr) // 2

    for i in range(num_groups):
        bits = gr[i * 2]
        if i < num_groups - 1:
            qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
        else:
            qrows = num_qrows - gr[i * 2 + 1]
        rows = qrows * 32 // bits
        for j in range(rows):
            group_map += [i]
            group_map += [rows - j]

    return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)

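To make the layout of this group map concrete, here is a minimal usage sketch with a made-up two-group `q_groups` tensor and `num_qrows` value: for every weight row, the map stores the group index followed by the number of rows remaining in that group, which is what the CUDA kernels read via `b_q_group_map[k * 2]` and `b_q_group_map[k * 2 + 1]`.

```python
import torch

# Hypothetical EXL2 layout: two 4-bit groups whose packed rows start at qrow 0 and qrow 2
q_groups = torch.tensor([4, 0, 4, 2], dtype=torch.short)
group_map = make_group_map(q_groups, num_qrows=4)

# Each group spans 2 qrows * 32 / 4 bits = 16 weight rows
assert group_map[0].item() == 0    # row 0 belongs to group 0
assert group_map[1].item() == 16   # 16 rows of group 0 remain at row 0
assert group_map[32].item() == 1   # row 16 (entry 16 * 2) belongs to group 1
assert group_map[33].item() == 16  # and starts a fresh run of 16 rows
```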
# Create Q matrix


def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
    """
    Create Q matrix
@@ -37,6 +63,10 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
        w["q_scale_max"] /= 256
        w["q_perm"] = w["q_perm"].short()
        w["q_invperm"] = w["q_invperm"].short()

        if "q_group_map" not in w:
            w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])

        return make_q_matrix(
            w["q_weight"],
            w["q_perm"],
@@ -44,6 +74,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
            w["q_scale"],
            w["q_scale_max"],
            w["q_groups"],
            w["q_group_map"],
            none_tensor,
            none_tensor,
            none_tensor,
@@ -70,6 +101,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
            none_tensor,
            none_tensor,
            none_tensor,
            none_tensor,
            w["qzeros"],
            w["scales"],
            w["g_idx"].cpu(),
@@ -84,6 +116,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
            none_tensor,
            none_tensor,
            none_tensor,
            none_tensor,
            w["qzeros"],
            w["scales"],
            none_tensor,
@@ -18,7 +18,9 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]


def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]:
def _cached_weight_files(
    model_id: str, revision: Optional[str], extension: str
) -> List[str]:
    """Guess weight files from the cached revision snapshot directory"""
    d = _get_cached_revision_directory(model_id, revision)
    if not d:
@@ -27,7 +29,9 @@ def _cached_weight_files(model_id: str, revision: Optional[str], extension: str)
    return filenames


def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]:
def _weight_hub_files_from_model_info(
    info: hf_api.ModelInfo, extension: str
) -> List[str]:
    return [
        s.rfilename
        for s in info.siblings
@@ -44,20 +48,27 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
    # see _weight_hub_files_from_model_info, that's also what is
    # done there with the len(s.rfilename.split("/")) == 1 condition
    root, _, files = next(os.walk(str(d)))
    filenames = [f for f in files
                 if f.endswith(extension)
                 and "arguments" not in f
                 and "args" not in f
                 and "training" not in f]
    filenames = [
        os.path.join(root, f)
        for f in files
        if f.endswith(extension)
        and "arguments" not in f
        and "args" not in f
        and "adapter" not in f
        and "training" not in f
    ]
    return filenames


def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]:
def _get_cached_revision_directory(
    model_id: str, revision: Optional[str]
) -> Optional[Path]:
    if revision is None:
        revision = "main"

    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
        file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
        file_download.repo_folder_name(repo_id=model_id, repo_type="model")
    )

    if not repo_cache.is_dir():
        # No cache for this model
@@ -85,7 +96,7 @@ def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Op

def weight_hub_files(
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
    """Get the weights filenames on the hub"""
    api = HfApi()
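For orientation, the snapshot directory that `_get_cached_revision_directory` resolves follows the standard Hugging Face hub cache layout, and `_weight_files_from_dir` now returns absolute paths from that directory, which is why the updated test earlier in this diff splits each returned path and only checks the filename. A minimal sketch, assuming a model that is already present in the local cache (the model id is only an example):

```python
from pathlib import Path

from huggingface_hub import file_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

model_id = "bigscience/bloom-560m"  # example model id, assumed to be cached locally
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / file_download.repo_folder_name(
    repo_id=model_id, repo_type="model"
)
# Layout: <cache>/models--bigscience--bloom-560m/snapshots/<revision-sha>/model.safetensors
for snapshot in (repo_cache / "snapshots").iterdir():
    print(sorted(str(p) for p in snapshot.glob("*.safetensors")))
```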
@@ -19,6 +19,7 @@ from accelerate import init_empty_weights

from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.log import log_once

HAS_AWQ = True
try:
@@ -35,10 +36,11 @@ HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
    logger.warning(
    V2 = False
    log_once(
        logger.warning,
        "Disabling exllama v2 and using v1 instead because there are issues when sharding"
    )
    V2 = False

if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
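The fallback described by that warning can be seen in isolation with a small sketch that mirrors the environment-variable logic above; the standalone script is only illustrative and is not part of the patch:

```python
import os

# Mirrors the version-selection logic above: exllama v2 is the default,
# but sharded deployments (WORLD_SIZE > 1) fall back to v1.
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
    V2 = False  # the server logs this once via log_once(logger.warning, ...)

print("will try exllama", "v2" if V2 else "v1")
```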
server/text_generation_server/utils/log.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from functools import lru_cache


@lru_cache(10)
def log_once(log, msg: str):
    log(msg)
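Because `log_once` is memoized with `lru_cache`, repeated calls with the same logging function and message are dropped after the first one (for up to the 10 most recently seen pairs). A minimal usage sketch; the message string is one that the weights loader later in this diff actually passes:

```python
from loguru import logger

# First call emits the warning; the second hits the lru_cache and does nothing.
log_once(logger.warning, "Disabling exllama because desc_act=True")
log_once(logger.warning, "Disabling exllama because desc_act=True")
```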
@@ -6,6 +6,7 @@ import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.utils.log import log_once


class Weights:
@@ -161,7 +162,7 @@ class Weights:
            else:
                g_idx = None

            bits, groupsize = self._get_gptq_params()
            bits, groupsize, _ = self._get_gptq_params()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
@@ -211,10 +212,10 @@ class Weights:
            else:
                g_idx = None

            bits, groupsize = self._get_gptq_params()
            bits, groupsize, desc_act = self._get_gptq_params()
            from text_generation_server.utils.layers import HAS_EXLLAMA

            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -240,11 +241,15 @@ class Weights:
    def get_multi_weights_row(self, prefix: str, quantize: str):
        if quantize == "gptq":
            use_exllama = True
            bits, groupsize = self._get_gptq_params()
            bits, groupsize, desc_act = self._get_gptq_params()

            if bits != 4:
                use_exllama = False

            if desc_act:
                log_once(logger.warning, "Disabling exllama because desc_act=True")
                use_exllama = False

            if self.process_group.size() > 1:
                g_idx = self.get_tensor(f"{prefix}.g_idx")
                if g_idx is not None:
@@ -274,12 +279,18 @@ class Weights:
            if use_exllama:
                if not HAS_EXLLAMA:
                    if CAN_EXLLAMA:
                        logger.warning(
                        log_once(
                            logger.warning,
                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
                        )
                    use_exllama = False
                else:
                    logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
                    log_once(
                        logger.info,
                        f"Using exllama kernels v{HAS_EXLLAMA}"
                    )

                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

            if use_exllama and groupsize != -1:
                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
@@ -288,14 +299,12 @@ class Weights:
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")

            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

            if use_exllama:
                g_idx = g_idx - g_idx[0]

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        elif quantize == "awq":
            bits, groupsize = self._get_gptq_params()
            bits, groupsize, _ = self._get_gptq_params()

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -314,18 +323,20 @@ class Weights:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

    def _get_gptq_params(self) -> Tuple[int, int]:
    def _get_gptq_params(self) -> Tuple[int, int, int]:
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
            desc_act = False
        except (SafetensorError, RuntimeError) as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
                desc_act = getattr(self, "gptq_desc_act", False)
            except Exception:
                raise e

        return bits, groupsize
        return bits, groupsize, desc_act

    def _set_gptq_params(self, model_id, revision):
        filename = "config.json"
@@ -340,6 +351,7 @@ class Weights:
                data = json.load(f)
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]
            self.gptq_desc_act = data["quantization_config"]["desc_act"]
        except Exception:
            filename = "quantize_config.json"
            try:
@@ -353,6 +365,7 @@ class Weights:
                data = json.load(f)
            self.gptq_bits = data["bits"]
            self.gptq_groupsize = data["group_size"]
            self.gptq_desc_act = data["desc_act"]
        except Exception:
            filename = "quant_config.json"
            try:
@@ -366,5 +379,6 @@ class Weights:
                data = json.load(f)
            self.gptq_bits = data["w_bit"]
            self.gptq_groupsize = data["q_group_size"]
            self.gptq_desc_act = data["desc_act"]
        except Exception:
            pass