Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Commit b223ac70b6: Merge branch 'huggingface:main' into main
@@ -39,7 +39,7 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
 
 ## Supported Hardware
 
 TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
 
 TGI also has support for ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and support may be extended in the future:
 
 * Quantization (GPTQ, AWQ, etc.)
@@ -47,5 +47,5 @@ TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with pag
 * Kernel for sliding window attention (Mistral)
 
 TGI is also supported on the following AI hardware accelerators:
-- *Habana first-gen Gaudi and Gaudi2:* check out this [example](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
+- *Habana first-gen Gaudi and Gaudi2:* check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
 * *AWS Inferentia2:* check out this [guide](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) on how to serve models with TGI on Inferentia2.
@@ -2,6 +2,7 @@
 #define _config_h
 
 #define MAX_Q_GEMM_ROWS 50
+#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS
 
 #define QMODE_2BIT 1
 #define QMODE_3BIT 1
@@ -10,4 +11,5 @@
 #define QMODE_6BIT 0
 #define QMODE_8BIT 0
 
+
 #endif
@@ -10,16 +10,19 @@
 #include "quant/qdq_6.cuh"
 #include "quant/qdq_8.cuh"
 
-#define BLOCK_KN_SIZE 128
-#define BLOCK_M_SIZE_MAX 8
-#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
+#define GPTQ_BLOCK_KN_SIZE 128
+#define GPTQ_BLOCK_M_SIZE_MAX 8
+#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
+
+#define EXL2_BLOCK_KN_SIZE 64
+#define EXL2_BLOCK_M_SIZE_MAX 8
+#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
 
 #define CLEAR_N_SIZE 256
 
 #include "q_gemm_kernel.cuh"
 #include "q_gemm_kernel_gptq.cuh"
 
-#include "compat_gemm.cuh"
 
 void gemm_half_q_half_cuda_part
 (
     const half* a,
@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part
|
|||||||
int size_n,
|
int size_n,
|
||||||
int size_k,
|
int size_k,
|
||||||
int m_count,
|
int m_count,
|
||||||
bool clear
|
bool clear,
|
||||||
|
const half* r_weights,
|
||||||
|
int r_weights_stride,
|
||||||
|
bool mul_r_weights
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
if (!b->is_gptq)
|
if (!b->is_gptq)
|
||||||
{
|
{
|
||||||
dim3 blockDim, gridDim;
|
dim3 blockDim, gridDim;
|
||||||
blockDim.x = BLOCK_KN_SIZE;
|
blockDim.x = EXL2_BLOCK_KN_SIZE;
|
||||||
blockDim.y = 1;
|
blockDim.y = 1;
|
||||||
blockDim.z = 1;
|
blockDim.z = 1;
|
||||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
|
||||||
gridDim.y = DIVIDE(size_m, m_count);
|
gridDim.y = DIVIDE(size_m, m_count);
|
||||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);
|
||||||
|
|
||||||
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
|
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
|
||||||
|
|
||||||
kernel<<<gridDim, blockDim>>>
|
kernel<<<gridDim, blockDim>>>
|
||||||
(
|
(
|
||||||
@ -55,7 +61,7 @@ void gemm_half_q_half_cuda_part
|
|||||||
size_n,
|
size_n,
|
||||||
size_k,
|
size_k,
|
||||||
b->groups,
|
b->groups,
|
||||||
b->groupsize,
|
b->cuda_q_group_map,
|
||||||
b->cuda_q_perm,
|
b->cuda_q_perm,
|
||||||
b->rows_8,
|
b->rows_8,
|
||||||
b->rows_6,
|
b->rows_6,
|
||||||
@ -63,24 +69,27 @@ void gemm_half_q_half_cuda_part
|
|||||||
b->rows_4,
|
b->rows_4,
|
||||||
b->rows_3,
|
b->rows_3,
|
||||||
b->rows_2,
|
b->rows_2,
|
||||||
clear
|
clear,
|
||||||
|
r_weights,
|
||||||
|
r_weights_stride
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
dim3 blockDim, gridDim;
|
dim3 blockDim, gridDim;
|
||||||
blockDim.x = BLOCK_KN_SIZE;
|
blockDim.x = GPTQ_BLOCK_KN_SIZE;
|
||||||
blockDim.y = 1;
|
blockDim.y = 1;
|
||||||
blockDim.z = 1;
|
blockDim.z = 1;
|
||||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
|
||||||
gridDim.y = DIVIDE(size_m, m_count);
|
gridDim.y = DIVIDE(size_m, m_count);
|
||||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
|
||||||
|
|
||||||
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
|
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);
|
||||||
|
|
||||||
// DBGX((uint64_t) b->cuda_q_perm);
|
// DBGX((uint64_t) r_weights);
|
||||||
// DBGI(b->rows_4);
|
// if (r_weights)
|
||||||
// DBGI(b->height);
|
// print_global_mem(r_weights, 1, 1, 1);
|
||||||
|
// DBGI(r_weights_stride);
|
||||||
|
|
||||||
kernel<<<gridDim, blockDim>>>
|
kernel<<<gridDim, blockDim>>>
|
||||||
(
|
(
|
||||||
@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part
|
|||||||
size_n,
|
size_n,
|
||||||
size_k,
|
size_k,
|
||||||
b->groups,
|
b->groups,
|
||||||
b->groupsize,
|
b->gptq_groupsize,
|
||||||
b->cuda_q_perm,
|
b->cuda_q_perm,
|
||||||
b->rows_4,
|
b->rows_4,
|
||||||
clear
|
clear,
|
||||||
|
r_weights,
|
||||||
|
r_weights_stride
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -112,13 +123,14 @@ void gemm_half_q_half_cuda
|
|||||||
int size_k,
|
int size_k,
|
||||||
bool clear,
|
bool clear,
|
||||||
half* temp_dq,
|
half* temp_dq,
|
||||||
bool force_cuda
|
bool force_cuda,
|
||||||
|
const half* r_weights,
|
||||||
|
const int r_weights_stride,
|
||||||
|
bool mul_r_weights
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
|
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
|
||||||
{
|
{
|
||||||
//printf("cublas\n");
|
|
||||||
|
|
||||||
// Reconstruct FP16 matrix, then cuBLAS
|
// Reconstruct FP16 matrix, then cuBLAS
|
||||||
|
|
||||||
if (!temp_dq) temp_dq = b->temp_dq;
|
if (!temp_dq) temp_dq = b->temp_dq;
|
||||||
@ -139,12 +151,12 @@ void gemm_half_q_half_cuda
|
|||||||
//const float alpha = 1.0f;
|
//const float alpha = 1.0f;
|
||||||
//const float beta = clear ? 0.0f : 1.0f;
|
//const float beta = clear ? 0.0f : 1.0f;
|
||||||
//cublasSgemmEx(cublas_handle,
|
//cublasSgemmEx(cublas_handle,
|
||||||
// CUBLAS_OP_N,
|
// CUBLAS_OP_N,
|
||||||
// CUBLAS_OP_N,
|
// CUBLAS_OP_N,
|
||||||
// size_n, size_m, size_k,
|
// size_n, size_m, size_k,
|
||||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||||
// a, CUDA_R_16F, size_k,
|
// a, CUDA_R_16F, size_k,
|
||||||
// &beta, c, CUDA_R_16F, size_n);
|
// &beta, c, CUDA_R_16F, size_n);
|
||||||
|
|
||||||
//const float alpha = 1.0f;
|
//const float alpha = 1.0f;
|
||||||
//const float beta = clear ? 0.0f : 1.0f;
|
//const float beta = clear ? 0.0f : 1.0f;
|
||||||
@@ -158,24 +170,21 @@ void gemm_half_q_half_cuda
     }
     else
    {
-        //printf("cuda\n");
 
         // Quantized matmul
 
-        //if (clear) clear_tensor_cuda(c, size_m, size_n);
-
-        int max_chunks = size_m / BLOCK_M_SIZE_MAX;
-        int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
+        int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
+        int max_chunks = size_m / block_m_size_max;
+        int last_chunk = max_chunks * block_m_size_max;
         int last_chunk_size = size_m - last_chunk;
 
         if (max_chunks)
         {
-            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
+            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
         }
 
         if (last_chunk_size)
         {
-            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
+            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
         }
     }
 }
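The chunking above caps how many rows of A each kernel launch handles: full chunks of `block_m_size_max` rows first, then one launch for the remainder whose `m_count` equals the leftover row count. Below is a minimal host-side sketch of that arithmetic, with a stub standing in for `gemm_half_q_half_cuda_part` (the stub and its printed output are illustrative only, not part of TGI):

```cpp
#include <cstdio>

// Stand-in for the real launcher; it only reports the slice it would receive
// (starting row and row count of the A/C matrices, plus the kernel's m_count).
static void launch_part(int row_offset, int rows, int m_count)
{
    std::printf("launch: rows [%d, %d) with m_count=%d\n",
                row_offset, row_offset + rows, m_count);
}

// Mirror of the chunking logic: full chunks of block_m_size_max rows,
// then one launch for the remainder.
static void gemm_chunked(int size_m, bool is_gptq)
{
    const int GPTQ_BLOCK_M_SIZE_MAX = 8;  // values from config.h / q_gemm.cu
    const int EXL2_BLOCK_M_SIZE_MAX = 8;
    int block_m_size_max = is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;

    int max_chunks = size_m / block_m_size_max;
    int last_chunk = max_chunks * block_m_size_max;
    int last_chunk_size = size_m - last_chunk;

    if (max_chunks)      launch_part(0, last_chunk, block_m_size_max);
    if (last_chunk_size) launch_part(last_chunk, last_chunk_size, last_chunk_size);
}

int main()
{
    gemm_chunked(19, false);  // 19 rows -> [0,16) with m_count=8, then [16,19) with m_count=3
    return 0;
}
```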
@@ -201,11 +210,10 @@ void clear_tensor_cuda
     int size_n
 )
 {
-    return;
-    dim3 blockDim, gridDim;
-    blockDim.x = CLEAR_N_SIZE;
-    blockDim.y = 1;
-    gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
-    gridDim.y = size_m;
-    clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
+    // dim3 blockDim, gridDim;
+    // blockDim.x = CLEAR_N_SIZE;
+    // blockDim.y = 1;
+    // gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
+    // gridDim.y = size_m;
+    // clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
 }
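With `clear_tensor_cuda` reduced to a commented-out body (it already bailed out with an early `return`), zeroing of the output has to happen elsewhere: both GEMM kernels accumulate into `c` with `atomicAdd`, so the destination must start at zero. A hedged sketch of one straightforward way to guarantee that on the host side, not taken from the TGI sources:

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Illustrative helper: zero an FP16 output matrix before a kernel that
// accumulates into it with atomicAdd. All-zero bytes are a valid half +0.0,
// so a plain byte-wise memset is sufficient.
inline cudaError_t zero_output(half* c, int size_m, int size_n, cudaStream_t stream)
{
    size_t bytes = (size_t)size_m * (size_t)size_n * sizeof(half);
    return cudaMemsetAsync(c, 0, bytes, stream);
}

int main()
{
    const int size_m = 4, size_n = 4096;
    half* c = NULL;
    cudaMalloc(&c, (size_t)size_m * size_n * sizeof(half));
    zero_output(c, size_m, size_n, 0);  // default stream
    cudaDeviceSynchronize();
    cudaFree(c);
    return 0;
}
```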
@@ -20,7 +20,10 @@ void gemm_half_q_half_cuda
     int size_k,
     bool clear = false,
     half* reconstruct = NULL,
-    bool force_cuda = false
+    bool force_cuda = false,
+    const half* r_weights = NULL,
+    const int r_weights_stride = 0,
+    bool mul_r_weights = false
 );
 
 void clear_tensor_cuda
|
@ -1,8 +1,5 @@
|
|||||||
#include "compat.cuh"
|
#include "compat.cuh"
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
|
|
||||||
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
|
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||||
{
|
{
|
||||||
half2 result = {};
|
half2 result = {};
|
||||||
@ -60,6 +57,47 @@ __forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, c
|
|||||||
return fma(result_f, qs_f, g_result);
|
return fma(result_f, qs_f, g_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
|
||||||
|
{
|
||||||
|
// Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
|
||||||
|
|
||||||
|
float result = {};
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++)
|
||||||
|
{
|
||||||
|
half2 w01 = dq[i];
|
||||||
|
float w0 = __low2float(w01);
|
||||||
|
float w1 = __high2float(w01);
|
||||||
|
float x0 = __half2float(*a_ptr++);
|
||||||
|
float x1 = __half2float(*a_ptr++);
|
||||||
|
result = fma(w0, x0, result);
|
||||||
|
result = fma(w1, x1, result);
|
||||||
|
}
|
||||||
|
float qs = __half2float(qs_h);
|
||||||
|
result *= qs;
|
||||||
|
half result_h = __float2half_rn(result);
|
||||||
|
return __hadd(result_h, g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
half result_h = __hadd(__low2half(result), __high2half(result));
|
||||||
|
return __hfma(result_h, qs_h, g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
half result_h = __hadd(__low2half(result), __high2half(result));
|
||||||
|
return __hfma(result_h, qs_h, g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
typedef void (*fp_gemm_half_q_half_kernel)
|
typedef void (*fp_gemm_half_q_half_kernel)
|
||||||
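The new `dot22_8_h` added above follows its own comment: it accumulates in FP32 because the unscaled 8-bit weights span -128..127, and a running half-precision sum of such products can exceed half's ~65504 range before the group scale is applied. A small standalone CUDA sketch of the same accumulate-in-float, round-once-at-the-end pattern (the kernel name and test values are ours, not from TGI):

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

// Illustrative kernel: dot product of 8 half weights in the range -128..127
// with 8 half activations, accumulated in FP32 as dot22_8_h does, then scaled
// and rounded back to half once at the end.
__global__ void dot8_fp32_acc(const half* w, const half* x, half scale, half* out)
{
    float acc = 0.0f;
    #pragma unroll
    for (int i = 0; i < 8; i++)
        acc = fmaf(__half2float(w[i]), __half2float(x[i]), acc);
    acc *= __half2float(scale);   // apply the group scale once
    *out = __float2half_rn(acc);  // single rounding step at the end
}

int main()
{
    half h_w[8], h_x[8], h_out, h_scale = __float2half(0.01f);
    for (int i = 0; i < 8; i++) { h_w[i] = __float2half(120.0f); h_x[i] = __float2half(100.0f); }

    half *d_w, *d_x, *d_out;
    cudaMalloc(&d_w, sizeof(h_w)); cudaMalloc(&d_x, sizeof(h_x)); cudaMalloc(&d_out, sizeof(half));
    cudaMemcpy(d_w, h_w, sizeof(h_w), cudaMemcpyHostToDevice);
    cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);

    dot8_fp32_acc<<<1, 1>>>(d_w, d_x, h_scale, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(half), cudaMemcpyDeviceToHost);

    // The intermediate sum 8 * 120 * 100 = 96000 would already exceed half's
    // maximum if summed in half precision; the FP32 accumulator keeps it exact
    // until the final scale brings it back into range.
    printf("result = %f\n", __half2float(h_out));
    cudaFree(d_w); cudaFree(d_x); cudaFree(d_out);
    return 0;
}
```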
@ -73,7 +111,7 @@ typedef void (*fp_gemm_half_q_half_kernel)
|
|||||||
const int,
|
const int,
|
||||||
const int,
|
const int,
|
||||||
const int,
|
const int,
|
||||||
const int,
|
const uint16_t*,
|
||||||
const uint16_t*,
|
const uint16_t*,
|
||||||
const int,
|
const int,
|
||||||
const int,
|
const int,
|
||||||
@ -81,10 +119,12 @@ typedef void (*fp_gemm_half_q_half_kernel)
|
|||||||
const int,
|
const int,
|
||||||
const int,
|
const int,
|
||||||
const int,
|
const int,
|
||||||
const bool
|
const bool,
|
||||||
|
const half*,
|
||||||
|
const int
|
||||||
);
|
);
|
||||||
|
|
||||||
template <bool first_block, int m_count>
|
template <int m_count, bool use_r_weights, bool mul_r_weights>
|
||||||
__global__ void gemm_half_q_half_kernel
|
__global__ void gemm_half_q_half_kernel
|
||||||
(
|
(
|
||||||
const half* __restrict__ a,
|
const half* __restrict__ a,
|
||||||
@ -96,7 +136,7 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
const int size_n,
|
const int size_n,
|
||||||
const int size_k,
|
const int size_k,
|
||||||
const int groups,
|
const int groups,
|
||||||
const int groupsize,
|
const uint16_t* __restrict__ b_q_group_map,
|
||||||
const uint16_t* __restrict__ b_q_perm,
|
const uint16_t* __restrict__ b_q_perm,
|
||||||
const int rows_8,
|
const int rows_8,
|
||||||
const int rows_6,
|
const int rows_6,
|
||||||
@ -104,7 +144,9 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
const int rows_4,
|
const int rows_4,
|
||||||
const int rows_3,
|
const int rows_3,
|
||||||
const int rows_2,
|
const int rows_2,
|
||||||
const bool clear
|
const bool clear,
|
||||||
|
const half* r_weights,
|
||||||
|
const int r_weights_stride
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
MatrixView_half a_(a, size_m, size_k);
|
MatrixView_half a_(a, size_m, size_k);
|
||||||
@ -115,18 +157,34 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
|
|
||||||
// Block
|
// Block
|
||||||
|
|
||||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4;
|
||||||
int offset_m = blockIdx.y * m_count;
|
int offset_m = blockIdx.y * m_count;
|
||||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE;
|
||||||
|
|
||||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n);
|
||||||
int end_m = min(offset_m + m_count, size_m);
|
int end_m = min(offset_m + m_count, size_m);
|
||||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k);
|
||||||
int n = offset_n + t * 4;
|
int n = offset_n + t * 4;
|
||||||
|
|
||||||
|
// Read weights
|
||||||
|
|
||||||
|
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
|
||||||
|
if constexpr (use_r_weights)
|
||||||
|
{
|
||||||
|
uint16_t any_w = 0;
|
||||||
|
const half* w_ptr = r_weights;
|
||||||
|
for (int m = 0; m < m_count; ++m)
|
||||||
|
{
|
||||||
|
weights[m].as_half = *w_ptr;
|
||||||
|
w_ptr += r_weights_stride;
|
||||||
|
any_w |= weights[m].as_uint16;
|
||||||
|
}
|
||||||
|
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
|
||||||
|
}
|
||||||
|
|
||||||
// Preload block_a
|
// Preload block_a
|
||||||
|
|
||||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
__shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE];
|
||||||
|
|
||||||
if (offset_k + t < end_k)
|
if (offset_k + t < end_k)
|
||||||
{
|
{
|
||||||
@ -135,6 +193,7 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||||
half* block_a_ptr = block_a[m];
|
half* block_a_ptr = block_a[m];
|
||||||
half a0 = a_ptr[b_q_perm[offset_k + t]];
|
half a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||||
|
// half a0 = a_ptr[offset_k + t];
|
||||||
block_a_ptr[t] = a0;
|
block_a_ptr[t] = a0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -153,14 +212,19 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
|
|
||||||
// Find initial group
|
// Find initial group
|
||||||
|
|
||||||
int group = offset_k / groupsize;
|
//int group = offset_k / groupsize;
|
||||||
|
int group = b_q_group_map[offset_k * 2];
|
||||||
|
|
||||||
|
// if (offset_m == 0 && t == 0)
|
||||||
|
// DBGI2(offset_k, group);
|
||||||
|
|
||||||
// Preload scales
|
// Preload scales
|
||||||
|
|
||||||
float scales[MAX_GROUPS_IN_BLOCK][4];
|
half scales[EXL2_MAX_GROUPS_IN_BLOCK][4];
|
||||||
|
|
||||||
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
//int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
||||||
for (int g = 0; g < groups_in_block; g++)
|
int temp_k = offset_k;
|
||||||
|
for (int g = 0; temp_k < end_k; g++)
|
||||||
{
|
{
|
||||||
int qscales[4];
|
int qscales[4];
|
||||||
b_q_scale_.item4(qscales, group + g, n);
|
b_q_scale_.item4(qscales, group + g, n);
|
||||||
@ -168,11 +232,12 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
qscales[1]++;
|
qscales[1]++;
|
||||||
qscales[2]++;
|
qscales[2]++;
|
||||||
qscales[3]++;
|
qscales[3]++;
|
||||||
float maxscale = __half2float(b_q_scale_max[group + g]);
|
half maxscale = b_q_scale_max[group + g];
|
||||||
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
|
scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale);
|
||||||
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
|
scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale);
|
||||||
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
|
scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale);
|
||||||
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
|
scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale);
|
||||||
|
temp_k += b_q_group_map[temp_k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
// a, b offset
|
// a, b offset
|
||||||
@ -193,20 +258,20 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
|
|
||||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||||
const half* a_ptr = &block_a[0][0];
|
const half* a_ptr = &block_a[0][0];
|
||||||
int a_stride = BLOCK_KN_SIZE;
|
int a_stride = EXL2_BLOCK_KN_SIZE;
|
||||||
|
|
||||||
// Initial group
|
// Initial group
|
||||||
|
|
||||||
int scales_idx = 0;
|
int scales_idx = 0;
|
||||||
float qs_f0 = scales[scales_idx][0];
|
half qs_h0 = scales[scales_idx][0];
|
||||||
float qs_f1 = scales[scales_idx][1];
|
half qs_h1 = scales[scales_idx][1];
|
||||||
float qs_f2 = scales[scales_idx][2];
|
half qs_h2 = scales[scales_idx][2];
|
||||||
float qs_f3 = scales[scales_idx][3];
|
half qs_h3 = scales[scales_idx][3];
|
||||||
int nextgroup = offset_k + groupsize;
|
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
|
||||||
|
|
||||||
// Column result
|
// Column result
|
||||||
|
|
||||||
float block_c[m_count][4] = {};
|
half block_c[m_count][4] = {};
|
||||||
|
|
||||||
// Dequantize groups
|
// Dequantize groups
|
||||||
|
|
||||||
@ -218,11 +283,11 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -240,10 +305,11 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
a_ptr += 8;
|
a_ptr += 8;
|
||||||
}
|
}
|
||||||
@ -256,11 +322,11 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -279,10 +345,11 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
a_ptr += 16;
|
a_ptr += 16;
|
||||||
}
|
}
|
||||||
@ -295,11 +362,11 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -320,10 +387,11 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
a_ptr += 32;
|
a_ptr += 32;
|
||||||
}
|
}
|
||||||
@ -337,11 +405,11 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -358,10 +426,11 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
a_ptr += 8;
|
a_ptr += 8;
|
||||||
}
|
}
|
||||||
@ -374,11 +443,11 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -397,10 +466,11 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
a_ptr += 32;
|
a_ptr += 32;
|
||||||
}
|
}
|
||||||
@ -413,15 +483,15 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
{
|
{
|
||||||
group++;
|
group++;
|
||||||
scales_idx++;
|
scales_idx++;
|
||||||
qs_f0 = scales[scales_idx][0];
|
qs_h0 = scales[scales_idx][0];
|
||||||
qs_f1 = scales[scales_idx][1];
|
qs_h1 = scales[scales_idx][1];
|
||||||
qs_f2 = scales[scales_idx][2];
|
qs_h2 = scales[scales_idx][2];
|
||||||
qs_f3 = scales[scales_idx][3];
|
qs_h3 = scales[scales_idx][3];
|
||||||
nextgroup += groupsize;
|
nextgroup += b_q_group_map[k * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 2; j++)
|
for (int j = 0; j < 1; j++)
|
||||||
{
|
{
|
||||||
int4 load_int4[1];
|
int4 load_int4[1];
|
||||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
@ -434,15 +504,16 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
|
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||||
|
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||||
}
|
}
|
||||||
|
|
||||||
a_ptr += 16;
|
a_ptr += 16;
|
||||||
}
|
}
|
||||||
k += 32;
|
k += 16;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate column sums in c
|
// Accumulate column sums in c
|
||||||
@ -450,38 +521,60 @@ __global__ void gemm_half_q_half_kernel
|
|||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
|
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
|
||||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
|
||||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
|
||||||
|
|
||||||
|
if constexpr (mul_r_weights)
|
||||||
|
{
|
||||||
|
half2 w_mul2 = __half2half2(weights[m].as_half);
|
||||||
|
result01 = __hmul2(result01, w_mul2);
|
||||||
|
result23 = __hmul2(result23, w_mul2);
|
||||||
|
}
|
||||||
|
|
||||||
atomicAdd(out , result01);
|
atomicAdd(out , result01);
|
||||||
atomicAdd(out + 1, result23);
|
atomicAdd(out + 1, result23);
|
||||||
|
// *out = result01;
|
||||||
|
// *(out + 1) = result23;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
|
template <bool use_r_weights, bool mul_r_weights>
|
||||||
|
struct map_m_count_exl2 {
|
||||||
|
static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
|
||||||
|
{
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 1
|
||||||
|
if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 2
|
||||||
|
if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 3
|
||||||
|
if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 4
|
||||||
|
if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 5
|
||||||
|
if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 6
|
||||||
|
if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 7
|
||||||
|
if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if EXL2_BLOCK_M_SIZE_MAX >= 8
|
||||||
|
if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights)
|
||||||
{
|
{
|
||||||
#if BLOCK_M_SIZE_MAX >= 1
|
if (!r_weights && !mul_r_weights) return map_m_count_exl2<false, false>::pick_gemm_half_q_half_kernel(m_count);
|
||||||
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
|
if (!r_weights && mul_r_weights) return map_m_count_exl2<false, true>::pick_gemm_half_q_half_kernel(m_count);
|
||||||
#endif
|
if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count);
|
||||||
#if BLOCK_M_SIZE_MAX >= 2
|
if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count);
|
||||||
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 3
|
|
||||||
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 4
|
|
||||||
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 5
|
|
||||||
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 6
|
|
||||||
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 7
|
|
||||||
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 8
|
|
||||||
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
|
|
||||||
#endif
|
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
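The `map_m_count_exl2` struct above replaces the old single-template dispatch: the two runtime booleans plus `m_count` now select one of up to 32 pre-instantiated kernels. The toy dispatcher below has the same shape and is only meant to make the pattern easier to follow; every name in it is illustrative:

```cpp
#include <cstdio>

typedef void (*kernel_fn)(int);

// Compile-time specializations selected by a small runtime ladder.
template <int m_count, bool use_w, bool mul_w>
void toy_kernel(int n)
{
    std::printf("m_count=%d use_w=%d mul_w=%d n=%d\n", m_count, (int)use_w, (int)mul_w, n);
}

template <bool use_w, bool mul_w>
struct map_m_count
{
    static kernel_fn pick(int m_count)
    {
        if (m_count == 1) return toy_kernel<1, use_w, mul_w>;
        if (m_count == 2) return toy_kernel<2, use_w, mul_w>;
        if (m_count == 3) return toy_kernel<3, use_w, mul_w>;
        return NULL;  // the real code continues up to EXL2_BLOCK_M_SIZE_MAX
    }
};

kernel_fn pick_kernel(int m_count, bool use_w, bool mul_w)
{
    if (!use_w && !mul_w) return map_m_count<false, false>::pick(m_count);
    if (!use_w &&  mul_w) return map_m_count<false, true>::pick(m_count);
    if ( use_w && !mul_w) return map_m_count<true, false>::pick(m_count);
    return                       map_m_count<true, true>::pick(m_count);
}

int main()
{
    kernel_fn k = pick_kernel(2, true, false);
    if (k) k(4096);
    return 0;
}
```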
@ -18,6 +18,15 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
|
|||||||
return __half2float(__low2half(result)) + __half2float(__high2half(result));
|
return __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
||||||
(
|
(
|
||||||
const half*,
|
const half*,
|
||||||
@ -32,10 +41,12 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
|||||||
const int,
|
const int,
|
||||||
const uint16_t*,
|
const uint16_t*,
|
||||||
const int,
|
const int,
|
||||||
const bool
|
const bool,
|
||||||
|
const half*,
|
||||||
|
const int
|
||||||
);
|
);
|
||||||
|
|
||||||
template <bool first_block, int m_count>
|
template <int m_count, bool use_r_weights, bool mul_r_weights>
|
||||||
__global__ void gemm_half_q_half_gptq_kernel
|
__global__ void gemm_half_q_half_gptq_kernel
|
||||||
(
|
(
|
||||||
const half* __restrict__ a,
|
const half* __restrict__ a,
|
||||||
@ -50,7 +61,9 @@ __global__ void gemm_half_q_half_gptq_kernel
|
|||||||
const int groupsize,
|
const int groupsize,
|
||||||
const uint16_t* __restrict__ b_q_perm,
|
const uint16_t* __restrict__ b_q_perm,
|
||||||
const int rows_4,
|
const int rows_4,
|
||||||
const bool clear
|
const bool clear,
|
||||||
|
const half* r_weights,
|
||||||
|
const int r_weights_stride
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
MatrixView_half a_(a, size_m, size_k);
|
MatrixView_half a_(a, size_m, size_k);
|
||||||
@ -62,19 +75,35 @@ __global__ void gemm_half_q_half_gptq_kernel
|
|||||||
|
|
||||||
// Block
|
// Block
|
||||||
|
|
||||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4;
|
||||||
int offset_m = blockIdx.y * m_count;
|
int offset_m = blockIdx.y * m_count;
|
||||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE;
|
||||||
|
|
||||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n);
|
||||||
int end_m = min(offset_m + m_count, size_m);
|
int end_m = min(offset_m + m_count, size_m);
|
||||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k);
|
||||||
|
|
||||||
int n = offset_n + t * 4;
|
int n = offset_n + t * 4;
|
||||||
|
|
||||||
|
// Read weights
|
||||||
|
|
||||||
|
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
|
||||||
|
if constexpr (use_r_weights)
|
||||||
|
{
|
||||||
|
uint16_t any_w = 0;
|
||||||
|
const half* w_ptr = r_weights;
|
||||||
|
for (int m = 0; m < m_count; ++m)
|
||||||
|
{
|
||||||
|
weights[m].as_half = *w_ptr;
|
||||||
|
w_ptr += r_weights_stride;
|
||||||
|
any_w |= weights[m].as_uint16;
|
||||||
|
}
|
||||||
|
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
|
||||||
|
}
|
||||||
|
|
||||||
// Preload block_a
|
// Preload block_a
|
||||||
|
|
||||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
__shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE];
|
||||||
|
|
||||||
if (offset_k + t < end_k)
|
if (offset_k + t < end_k)
|
||||||
{
|
{
|
||||||
@ -113,16 +142,16 @@ __global__ void gemm_half_q_half_gptq_kernel
|
|||||||
|
|
||||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||||
const half* a_ptr = &block_a[0][0];
|
const half* a_ptr = &block_a[0][0];
|
||||||
int a_stride = BLOCK_KN_SIZE;
|
int a_stride = GPTQ_BLOCK_KN_SIZE;
|
||||||
|
|
||||||
// Initial group
|
// Initial group
|
||||||
|
|
||||||
int zeros[4];
|
int zeros[4];
|
||||||
float scales[4];
|
half2 scales[4];
|
||||||
half2 z1z16[4][2];
|
half2 z1z16[4][2];
|
||||||
half2 y1y16[4][2];
|
half2 y1y16[4][2];
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_f(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||||
@ -132,7 +161,7 @@ __global__ void gemm_half_q_half_gptq_kernel
|
|||||||
|
|
||||||
// Column result
|
// Column result
|
||||||
|
|
||||||
float block_c[m_count][4] = {};
|
half2 block_c[m_count][4] = {};
|
||||||
|
|
||||||
// Dequantize and multiply
|
// Dequantize and multiply
|
||||||
|
|
||||||
@ -144,7 +173,7 @@ __global__ void gemm_half_q_half_gptq_kernel
|
|||||||
group++;
|
group++;
|
||||||
nextgroup += groupsize;
|
nextgroup += groupsize;
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_f(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||||
@ -166,10 +195,11 @@ __global__ void gemm_half_q_half_gptq_kernel
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||||
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
||||||
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
||||||
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
||||||
|
block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
b_ptr += size_n;
|
b_ptr += size_n;
|
||||||
@ -182,38 +212,62 @@ __global__ void gemm_half_q_half_gptq_kernel
|
|||||||
for (int m = 0; m < m_count; m++)
|
for (int m = 0; m < m_count; m++)
|
||||||
{
|
{
|
||||||
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
||||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0]));
|
||||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1]));
|
||||||
|
half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2]));
|
||||||
|
half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3]));
|
||||||
|
half2 result01 = __halves2half2(result0, result1);
|
||||||
|
half2 result23 = __halves2half2(result2, result3);
|
||||||
|
|
||||||
|
if constexpr (mul_r_weights)
|
||||||
|
{
|
||||||
|
half2 w_mul2 = __half2half2(weights[m].as_half);
|
||||||
|
result01 = __hmul2(result01, w_mul2);
|
||||||
|
result23 = __hmul2(result23, w_mul2);
|
||||||
|
}
|
||||||
|
|
||||||
atomicAdd(out , result01);
|
atomicAdd(out , result01);
|
||||||
atomicAdd(out + 1, result23);
|
atomicAdd(out + 1, result23);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
|
template <bool use_r_weights, bool mul_r_weights>
|
||||||
|
struct map_m_count_gptq {
|
||||||
|
static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count)
|
||||||
|
{
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 1
|
||||||
|
if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 2
|
||||||
|
if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 3
|
||||||
|
if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 4
|
||||||
|
if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 5
|
||||||
|
if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 6
|
||||||
|
if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 7
|
||||||
|
if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
#if GPTQ_BLOCK_M_SIZE_MAX >= 8
|
||||||
|
if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>;
|
||||||
|
#endif
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights)
|
||||||
{
|
{
|
||||||
#if BLOCK_M_SIZE_MAX >= 1
|
if (!r_weights && !mul_r_weights) return map_m_count_gptq<false, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||||
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
|
if (!r_weights && mul_r_weights) return map_m_count_gptq<false, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||||
#endif
|
if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||||
#if BLOCK_M_SIZE_MAX >= 2
|
if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||||
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 3
|
|
||||||
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 4
|
|
||||||
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 5
|
|
||||||
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 6
|
|
||||||
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 7
|
|
||||||
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
|
|
||||||
#endif
|
|
||||||
#if BLOCK_M_SIZE_MAX >= 8
|
|
||||||
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
|
|
||||||
#endif
|
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
@ -57,6 +57,7 @@ QMatrix::QMatrix
|
|||||||
uint32_t* _q_scale,
|
uint32_t* _q_scale,
|
||||||
half* _q_scale_max,
|
half* _q_scale_max,
|
||||||
uint16_t* _q_groups,
|
uint16_t* _q_groups,
|
||||||
|
uint16_t* _q_group_map,
|
||||||
|
|
||||||
uint32_t* _gptq_qzeros,
|
uint32_t* _gptq_qzeros,
|
||||||
half* _gptq_scales,
|
half* _gptq_scales,
|
||||||
@ -80,13 +81,17 @@ QMatrix::QMatrix
|
|||||||
cuda_q_scale = _q_scale;
|
cuda_q_scale = _q_scale;
|
||||||
cuda_q_scale_max = _q_scale_max;
|
cuda_q_scale_max = _q_scale_max;
|
||||||
cuda_q_groups = _q_groups;
|
cuda_q_groups = _q_groups;
|
||||||
|
cuda_q_group_map = _q_group_map;
|
||||||
cuda_gptq_qzeros = _gptq_qzeros;
|
cuda_gptq_qzeros = _gptq_qzeros;
|
||||||
cuda_gptq_scales = _gptq_scales;
|
cuda_gptq_scales = _gptq_scales;
|
||||||
|
|
||||||
is_gptq = (_gptq_qzeros != NULL);
|
is_gptq = (_gptq_qzeros != NULL);
|
||||||
|
|
||||||
groupsize = 1;
|
if (is_gptq)
|
||||||
while (groupsize * groups < height) groupsize *= 2;
|
{
|
||||||
|
gptq_groupsize = 1;
|
||||||
|
while (gptq_groupsize * groups < height) gptq_groupsize *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
// Create group map
|
// Create group map
|
||||||
|
|
||||||
@ -102,15 +107,26 @@ QMatrix::QMatrix
|
|||||||
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
|
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
|
||||||
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
||||||
|
|
||||||
|
int row = 0;
|
||||||
for (int i = 0; i < groups; i++)
|
for (int i = 0; i < groups; i++)
|
||||||
{
|
{
|
||||||
int bits = cpu_q_groups[i * 2];
|
int bits = cpu_q_groups[i * 2];
|
||||||
if (bits == 8) rows_8 += groupsize;
|
|
||||||
if (bits == 6) rows_6 += groupsize;
|
int rows;
|
||||||
if (bits == 5) rows_5 += groupsize;
|
if (i < groups - 1)
|
||||||
if (bits == 4) rows_4 += groupsize;
|
{
|
||||||
if (bits == 3) rows_3 += groupsize;
|
int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
|
||||||
if (bits == 2) rows_2 += groupsize;
|
rows = qrows * 32 / bits;
|
||||||
|
}
|
||||||
|
else rows = height - row;
|
||||||
|
|
||||||
|
if (bits == 8) rows_8 += rows;
|
||||||
|
if (bits == 6) rows_6 += rows;
|
||||||
|
if (bits == 5) rows_5 += rows;
|
||||||
|
if (bits == 4) rows_4 += rows;
|
||||||
|
if (bits == 3) rows_3 += rows;
|
||||||
|
if (bits == 2) rows_2 += rows;
|
||||||
|
row += rows;
|
||||||
}
|
}
|
||||||
|
|
||||||
free(cpu_q_groups);
|
free(cpu_q_groups);
|
||||||
@ -138,6 +154,13 @@ QMatrix::QMatrix
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DBGI(rows_8);
|
||||||
|
// DBGI(rows_6);
|
||||||
|
// DBGI(rows_5);
|
||||||
|
// DBGI(rows_4);
|
||||||
|
// DBGI(rows_3);
|
||||||
|
// DBGI(rows_2);
|
||||||
|
|
||||||
// Shuffle quantized data
|
// Shuffle quantized data
|
||||||
|
|
||||||
dim3 blockDim, gridDim;
|
dim3 blockDim, gridDim;
|
||||||
@ -283,10 +306,10 @@ __global__ void reconstruct_kernel
|
|||||||
const uint16_t* __restrict__ b_q_perm,
|
const uint16_t* __restrict__ b_q_perm,
|
||||||
const uint32_t* __restrict__ b_q_scale,
|
const uint32_t* __restrict__ b_q_scale,
|
||||||
const half* __restrict__ b_q_scale_max,
|
const half* __restrict__ b_q_scale_max,
|
||||||
//const uint16_t* __restrict__ b_q_groups,
|
const uint16_t* __restrict__ b_q_group_map,
|
||||||
const int size_k,
|
const int size_k,
|
||||||
const int size_n,
|
const int size_n,
|
||||||
const int groupsize,
|
//const int groupsize,
|
||||||
const int groups,
|
const int groups,
|
||||||
half* __restrict__ b,
|
half* __restrict__ b,
|
||||||
const int rows_8,
|
const int rows_8,
|
||||||
@ -317,7 +340,8 @@ __global__ void reconstruct_kernel
|
|||||||
|
|
||||||
// Find initial group
|
// Find initial group
|
||||||
|
|
||||||
int group = offset_k / groupsize;
|
// int group = offset_k / groupsize;
|
||||||
|
int group = b_q_group_map[offset_k * 2];
|
||||||
|
|
||||||
int pre_rows_8 = min(rows_8, offset_k);
|
int pre_rows_8 = min(rows_8, offset_k);
|
||||||
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||||
@ -337,7 +361,7 @@ __global__ void reconstruct_kernel
|
|||||||
|
|
||||||
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
|
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
|
||||||
half2 qs_h2 = __halves2half2(qs_h, qs_h);
|
half2 qs_h2 = __halves2half2(qs_h, qs_h);
|
||||||
int nextgroup = offset_k + groupsize;
|
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
|
||||||
|
|
||||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||||
int k = offset_k;
|
int k = offset_k;
|
||||||
@ -347,7 +371,7 @@ __global__ void reconstruct_kernel
|
|||||||
|
|
||||||
while (k < rows_8 && k < end_k)
|
while (k < rows_8 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 4; p++)
|
for (int p = 0; p < 4; p++)
|
||||||
{
|
{
|
||||||
half2 dq[4];
|
half2 dq[4];
|
||||||
@ -363,7 +387,7 @@ __global__ void reconstruct_kernel
|
|||||||
|
|
||||||
while (k < rows_6 && k < end_k)
|
while (k < rows_6 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 2; p++)
|
for (int p = 0; p < 2; p++)
|
||||||
{
|
{
|
||||||
half2 dq[8];
|
half2 dq[8];
|
||||||
@ -380,7 +404,7 @@ __global__ void reconstruct_kernel
|
|||||||
|
|
||||||
while (k < rows_5 && k < end_k)
|
while (k < rows_5 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 1; p++)
|
for (int p = 0; p < 1; p++)
|
||||||
{
|
{
|
||||||
half2 dq[16];
|
half2 dq[16];
|
||||||
@ -399,7 +423,7 @@ __global__ void reconstruct_kernel
|
|||||||
|
|
||||||
while (k < rows_4 && k < end_k)
|
while (k < rows_4 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 4; p++)
|
for (int p = 0; p < 4; p++)
|
||||||
{
|
{
|
||||||
half2 dq[4];
|
half2 dq[4];
|
||||||
@ -414,7 +438,7 @@ __global__ void reconstruct_kernel
|
|||||||
|
|
||||||
while (k < rows_3 && k < end_k)
|
while (k < rows_3 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 1; p++)
|
for (int p = 0; p < 1; p++)
|
||||||
{
|
{
|
||||||
half2 dq[16];
|
half2 dq[16];
|
||||||
@ -431,8 +455,8 @@ __global__ void reconstruct_kernel
|
|||||||
|
|
||||||
while (k < rows_2 && k < end_k)
|
while (k < rows_2 && k < end_k)
|
||||||
{
|
{
|
||||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
for (int p = 0; p < 2; p++)
|
for (int p = 0; p < 1; p++)
|
||||||
{
|
{
|
||||||
half2 dq[8];
|
half2 dq[8];
|
||||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||||
@ -441,7 +465,7 @@ __global__ void reconstruct_kernel
|
|||||||
half* dqh = (half*) dq;
|
half* dqh = (half*) dq;
|
||||||
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||||
}
|
}
|
||||||
k += 32;
|
k += 16;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||

@@ -461,10 +485,10 @@ void QMatrix::reconstruct(half* out)
         cuda_q_perm,
         cuda_q_scale,
         cuda_q_scale_max,
-        //cuda_q_groups,
+        cuda_q_group_map,
         height,
         width,
-        groupsize,
+        //groupsize,
         groups,
         out,
         rows_8,

@@ -487,7 +511,7 @@ void QMatrix::reconstruct(half* out)
         //const uint16_t* __restrict__ b_q_groups,
         height,
         width,
-        groupsize,
+        gptq_groupsize,
         groups,
         out,
         rows_4
@@ -18,7 +18,7 @@ public:
     int height;
     int width;
     int groups;
-    int groupsize;
+    int gptq_groupsize;

     int rows_8;
     int rows_6;

@@ -33,6 +33,7 @@ public:
     uint32_t* cuda_q_scale = NULL;
     half* cuda_q_scale_max = NULL;
     uint16_t* cuda_q_groups = NULL;
+    uint16_t* cuda_q_group_map = NULL;
     uint32_t* cuda_gptq_qzeros = NULL;
     half* cuda_gptq_scales = NULL;

@@ -53,6 +54,7 @@ public:
         uint32_t* _q_scale,
         half* _q_scale_max,
         uint16_t* _q_groups,
+        uint16_t* _q_group_map,

         uint32_t* _gptq_qzeros,
         half* _gptq_scales,
@@ -7,6 +7,7 @@ union half2_uint32
     half2 as_half2;
     __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
     __device__ half2_uint32(half2 val) : as_half2(val) {}
+    __device__ half2_uint32() : as_uint32(0) {}
 };

 union half_uint16

@@ -15,6 +16,7 @@ union half_uint16
     half as_half;
     __device__ half_uint16(uint16_t val) : as_uint16(val) {}
     __device__ half_uint16(half val) : as_half(val) {}
+    __device__ half_uint16() : as_uint16(0) {}
 };

 // Max_scale premultiplied by 1/256
@@ -1,3 +1,11 @@
+#ifndef _util_cuh
+#define _util_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+#include <ATen/cuda/CUDAContext.h>
+
 #define DIVIDE(x, size) (((x) + (size) - 1) / (size))

@@ -40,3 +48,7 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=
         if (abort) exit(code);
     }
 }
+
+void print_global_mem(const half* ptr, int rows, int columns, int stride);
+
+#endif
@@ -31,6 +31,7 @@ uintptr_t make_q_matrix
     torch::Tensor q_scale,
     torch::Tensor q_scale_max,
     torch::Tensor q_groups,
+    torch::Tensor q_group_map,
     torch::Tensor gptq_qzeros,
     torch::Tensor gptq_scales,
    torch::Tensor gptq_g_idx,

@@ -43,6 +44,7 @@ uintptr_t make_q_matrix
     TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
     TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
     TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
+    TORCH_CHECK_DTYPE_OPT(q_group_map, kShort);
     TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
     TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
     TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);

@@ -83,12 +85,15 @@ uintptr_t make_q_matrix
         q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
         q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
         q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
+        q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(),
         gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
         gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
         gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
         (half*) temp_dq.data_ptr()
     );

+    if (m->failed) throw std::runtime_error("CUDA out of memory");
+
     return reinterpret_cast<uintptr_t> (m);
 }
@@ -32,10 +32,10 @@ def fresh_cache():
         current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
         huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
         text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
-        os.environ['HUGGINGFACE_HUB_CACHE'] = d
+        os.environ["HUGGINGFACE_HUB_CACHE"] = d
         yield
         huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
-        os.environ['HUGGINGFACE_HUB_CACHE'] = current_value
+        os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
         text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value

@@ -47,7 +47,7 @@ def prefetched():
         revision="main",
         local_files_only=False,
         repo_type="model",
-        allow_patterns=["*.safetensors"]
+        allow_patterns=["*.safetensors"],
     )
     yield model_id

@@ -61,7 +61,15 @@ def test_weight_hub_files_offline_error(offline, fresh_cache):
 def test_weight_hub_files_offline_ok(prefetched, offline):
     # If the model is prefetched then we should be able to get the weight files from local cache
     filenames = weight_hub_files(prefetched)
-    assert filenames == ['model.safetensors']
+    root = None
+    assert len(filenames) == 1
+    for f in filenames:
+        curroot, filename = os.path.split(f)
+        if root is None:
+            root = curroot
+        else:
+            assert root == curroot
+        assert filename == "model.safetensors"


 def test_weight_hub_files():
@@ -71,7 +71,7 @@ def _load_multi_mqa_gptq(

         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
         g_idx = g_idx.to(device=weights.device)
-        bits, groupsize = weights._get_gptq_params()
+        bits, groupsize, _ = weights._get_gptq_params()

         from text_generation_server.utils.layers import HAS_EXLLAMA
|
@ -27,6 +27,32 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
|||||||
return output.view(output_shape)
|
return output.view(output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
# Group map needed for irregular group sizes
|
||||||
|
|
||||||
|
|
||||||
|
def make_group_map(q_groups, num_qrows):
|
||||||
|
|
||||||
|
gr = q_groups.tolist()
|
||||||
|
group_map = []
|
||||||
|
num_groups = len(gr) // 2
|
||||||
|
|
||||||
|
for i in range(num_groups):
|
||||||
|
bits = gr[i * 2]
|
||||||
|
if i < num_groups - 1:
|
||||||
|
qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
|
||||||
|
else:
|
||||||
|
qrows = num_qrows - gr[i * 2 + 1]
|
||||||
|
rows = qrows * 32 // bits
|
||||||
|
for j in range(rows):
|
||||||
|
group_map += [i]
|
||||||
|
group_map += [rows - j]
|
||||||
|
|
||||||
|
return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
|
||||||
|
|
||||||
|
|
||||||
|
# Create Q matrix
|
||||||
|
|
||||||
|
|
||||||
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||||
"""
|
"""
|
||||||
Create Q matrix
|
Create Q matrix
|
||||||
@@ -37,6 +63,10 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
         w["q_scale_max"] /= 256
         w["q_perm"] = w["q_perm"].short()
         w["q_invperm"] = w["q_invperm"].short()
+
+        if "q_group_map" not in w:
+            w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
+
         return make_q_matrix(
             w["q_weight"],
             w["q_perm"],

@@ -44,6 +74,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             w["q_scale"],
             w["q_scale_max"],
             w["q_groups"],
+            w["q_group_map"],
             none_tensor,
             none_tensor,
             none_tensor,

@@ -70,6 +101,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             none_tensor,
             none_tensor,
             none_tensor,
+            none_tensor,
             w["qzeros"],
             w["scales"],
             w["g_idx"].cpu(),

@@ -84,6 +116,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             none_tensor,
             none_tensor,
             none_tensor,
+            none_tensor,
             w["qzeros"],
             w["scales"],
             none_tensor,
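For reference, here is a small sanity check of the group-map layout that the new `make_group_map` helper produces and the reconstruct kernels consume via `b_q_group_map[k * 2 + 1]`. This is a minimal sketch, not part of the change; the toy `q_groups` values are illustrative and assume the helper above is importable:

    import torch

    # Toy input: a single 4-bit group starting at quantized row 0, stored as
    # (bits, qrow_start) pairs; num_qrows=2 packed rows -> 2 * 32 // 4 = 16
    # dequantized rows.
    q_groups = torch.tensor([4, 0], dtype=torch.short)
    group_map = make_group_map(q_groups, num_qrows=2)

    # Entry 2*k is the group index owning dequantized row k; entry 2*k + 1 is
    # the number of rows left in that group (counting row k itself), which is
    # what the kernel now adds to `nextgroup` instead of a fixed `groupsize`.
    print(group_map[:6])  # -> tensor([0, 16, 0, 15, 0, 14], dtype=torch.int16)
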
@@ -18,7 +18,9 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
 HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]


-def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]:
+def _cached_weight_files(
+    model_id: str, revision: Optional[str], extension: str
+) -> List[str]:
     """Guess weight files from the cached revision snapshot directory"""
     d = _get_cached_revision_directory(model_id, revision)
     if not d:

@@ -27,7 +29,9 @@ def _cached_weight_files(model_id: str, revision: Optional[str], extension: str)
     return filenames


-def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]:
+def _weight_hub_files_from_model_info(
+    info: hf_api.ModelInfo, extension: str
+) -> List[str]:
     return [
         s.rfilename
         for s in info.siblings

@@ -44,20 +48,27 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
     # see _weight_hub_files_from_model_info, that's also what is
     # done there with the len(s.rfilename.split("/")) == 1 condition
     root, _, files = next(os.walk(str(d)))
-    filenames = [f for f in files
-                 if f.endswith(extension)
-                 and "arguments" not in f
-                 and "args" not in f
-                 and "training" not in f]
+    filenames = [
+        os.path.join(root, f)
+        for f in files
+        if f.endswith(extension)
+        and "arguments" not in f
+        and "args" not in f
+        and "adapter" not in f
+        and "training" not in f
+    ]
     return filenames


-def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]:
+def _get_cached_revision_directory(
+    model_id: str, revision: Optional[str]
+) -> Optional[Path]:
     if revision is None:
         revision = "main"

     repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
-        file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
+        file_download.repo_folder_name(repo_id=model_id, repo_type="model")
+    )

     if not repo_cache.is_dir():
         # No cache for this model

@@ -85,7 +96,7 @@ def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Op


 def weight_hub_files(
     model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
 ) -> List[str]:
     """Get the weights filenames on the hub"""
     api = HfApi()
@@ -19,6 +19,7 @@ from accelerate import init_empty_weights

 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.log import log_once

 HAS_AWQ = True
 try:

@@ -35,10 +36,11 @@ HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
-    logger.warning(
+    V2 = False
+    log_once(
+        logger.warning,
         "Disabling exllama v2 and using v1 instead because there are issues when sharding"
     )
-    V2 = False

 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
server/text_generation_server/utils/log.py (new file, 6 lines)
@@ -0,0 +1,6 @@
+from functools import lru_cache
+
+
+@lru_cache(10)
+def log_once(log, msg:str):
+    log(msg)
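A quick usage sketch of the new helper (not part of the change): because `lru_cache` memoizes on the `(log, msg)` pair, repeated identical calls are served from the cache and the message is emitted only once per process.

    from loguru import logger
    from text_generation_server.utils.log import log_once

    # Only the first call actually logs; the next two hit the lru_cache.
    for _ in range(3):
        log_once(logger.warning, "Disabling exllama because desc_act=True")
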
@@ -6,6 +6,7 @@ import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
+from text_generation_server.utils.log import log_once


 class Weights:

@@ -161,7 +162,7 @@ class Weights:
             else:
                 g_idx = None

-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, _ = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")

@@ -211,10 +212,10 @@ class Weights:
             else:
                 g_idx = None

-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, desc_act = self._get_gptq_params()
             from text_generation_server.utils.layers import HAS_EXLLAMA

-            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
+            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]

@@ -240,11 +241,15 @@ class Weights:
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, desc_act = self._get_gptq_params()

             if bits != 4:
                 use_exllama = False

+            if desc_act:
+                log_once(logger.warning, "Disabling exllama because desc_act=True")
+                use_exllama = False
+
             if self.process_group.size() > 1:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
                 if g_idx is not None:

@@ -274,12 +279,18 @@ class Weights:
             if use_exllama:
                 if not HAS_EXLLAMA:
                     if CAN_EXLLAMA:
-                        logger.warning(
+                        log_once(
+                            logger.warning,
                             "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
                         )
                     use_exllama = False
                 else:
-                    logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
+                    log_once(
+                        logger.info,
+                        f"Using exllama kernels v{HAS_EXLLAMA}"
+                    )
+
+            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

             if use_exllama and groupsize != -1:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)

@@ -288,14 +299,12 @@ class Weights:
                 qzeros = self.get_tensor(f"{prefix}.qzeros")
                 scales = self.get_tensor(f"{prefix}.scales")

-            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-
             if use_exllama:
                 g_idx = g_idx - g_idx[0]

             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, _ = self._get_gptq_params()

             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)

@@ -314,18 +323,20 @@ class Weights:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight

-    def _get_gptq_params(self) -> Tuple[int, int]:
+    def _get_gptq_params(self) -> Tuple[int, int, int]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
+            desc_act = False
         except (SafetensorError, RuntimeError) as e:
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
+                desc_act = getattr(self, "gptq_desc_act", False)
             except Exception:
                 raise e

-        return bits, groupsize
+        return bits, groupsize, desc_act

     def _set_gptq_params(self, model_id, revision):
         filename = "config.json"

@@ -340,6 +351,7 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
+            self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try:

@@ -353,6 +365,7 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["bits"]
             self.gptq_groupsize = data["group_size"]
+            self.gptq_desc_act = data["desc_act"]
         except Exception:
             filename = "quant_config.json"
             try:

@@ -366,5 +379,6 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["w_bit"]
             self.gptq_groupsize = data["q_group_size"]
+            self.gptq_desc_act = data["desc_act"]
         except Exception:
             pass
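Taken together, a minimal sketch of the resulting flow (illustrative only: `weights`, `quantize`, and the config values are stand-ins, and the logic simply mirrors the hunks above). A GPTQ checkpoint whose quantize_config.json sets `"desc_act": true` now surfaces that flag through `_get_gptq_params` and falls back to the regular GPTQ kernels instead of exllama:

    # e.g. quantize_config.json: {"bits": 4, "group_size": 128, "desc_act": true}
    bits, groupsize, desc_act = weights._get_gptq_params()

    use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
    if desc_act:
        log_once(logger.warning, "Disabling exllama because desc_act=True")
        use_exllama = False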