diff --git a/Makefile b/Makefile index 3c2f2b9de..81b312d15 100644 --- a/Makefile +++ b/Makefile @@ -56,3 +56,6 @@ run-bloom: run-bloom-quantize: text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080 + +clean: + rm -rf target aml diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 2e2bc7a58..12c68539b 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -20,6 +20,7 @@ mod env_runtime; enum Quantization { Bitsandbytes, Gptq, + Gptq_cuda, } impl std::fmt::Display for Quantization { @@ -32,10 +33,14 @@ impl std::fmt::Display for Quantization { Quantization::Gptq => { write!(f, "gptq") } + Quantization::Gptq_cuda => { + write!(f, "gptq-cuda") + } } } } + /// App Configuration #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] diff --git a/server/custom_kernels/custom_kernels/exllama/cuda_buffers.cu b/server/custom_kernels/custom_kernels/exllama/cuda_buffers.cu new file mode 100644 index 000000000..1219ece35 --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/cuda_buffers.cu @@ -0,0 +1,69 @@ +#define _cuda_buffers_cu +#include "cuda_buffers.cuh" + +CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; +// __constant__ half2 q4_table[16][256]; +// half2 q4_table_host[16][256]; +// bool q4_table_init = false; + +CudaBuffers::CudaBuffers +( + int _device, + half* _temp_state, + half* _temp_dq +) : + device(_device), + temp_state(_temp_state), + temp_dq(_temp_dq) +{ + cudaSetDevice(_device); + + cudaStreamCreate(&alt_stream_1); + cudaStreamCreate(&alt_stream_2); + cudaStreamCreate(&alt_stream_3); + cudaEventCreate(&alt_stream_1_done); + cudaEventCreate(&alt_stream_2_done); + cudaEventCreate(&alt_stream_3_done); +} + +CudaBuffers::~CudaBuffers() +{ + cudaStreamDestroy(alt_stream_1); + cudaStreamDestroy(alt_stream_2); + cudaStreamDestroy(alt_stream_3); + cudaEventDestroy(alt_stream_1_done); + cudaEventDestroy(alt_stream_2_done); + cudaEventDestroy(alt_stream_3_done); +} + +CudaBuffers* get_buffers(const int device_index) +{ + return g_buffers[device_index]; +} + +void prepare_buffers_cuda +( + int _device, + half* _temp_state, + half* _temp_dq +) +{ + CudaBuffers* buffers = new CudaBuffers + ( + _device, + _temp_state, + _temp_dq + ); + + g_buffers[_device] = buffers; +} + +void cleanup_buffers_cuda() +{ + for (int i = 0; i < CUDA_MAX_DEVICES; i++) + { + if (!g_buffers[i]) continue; + delete g_buffers[i]; + g_buffers[i] = NULL; + } +} diff --git a/server/custom_kernels/custom_kernels/exllama/cuda_buffers.cuh b/server/custom_kernels/custom_kernels/exllama/cuda_buffers.cuh new file mode 100644 index 000000000..dbe4ed168 --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/cuda_buffers.cuh @@ -0,0 +1,50 @@ +#ifndef _cuda_buffers_cuh +#define _cuda_buffers_cuh + +#include +#include +#include +#include + +const int CUDA_MAX_DEVICES = 16; + +// #ifndef _cuda_buffers_cu +// extern __constant__ half2 q4_table[16][256]; +// #endif + +class CudaBuffers +{ +public: + int device; + + half* temp_state; // [max_hidden_rows * intermediate_size] + half* temp_dq; // size of largest quant tensor * 8 + + cudaStream_t alt_stream_1; + cudaStream_t alt_stream_2; + cudaStream_t alt_stream_3; + cudaEvent_t alt_stream_1_done; + cudaEvent_t alt_stream_2_done; + cudaEvent_t alt_stream_3_done; + + CudaBuffers + ( + int _device, + half* _temp_state, + half* _temp_dq + ); + ~CudaBuffers(); +}; + +CudaBuffers* get_buffers(const int device_index); + +void prepare_buffers_cuda +( + int _device, + half* _temp_state, + half* 
_temp_dq +); + +void cleanup_buffers_cuda(); + +#endif diff --git a/server/custom_kernels/custom_kernels/exllama/cuda_compat.cuh b/server/custom_kernels/custom_kernels/exllama/cuda_compat.cuh new file mode 100644 index 000000000..aa1ed39da --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/cuda_compat.cuh @@ -0,0 +1,56 @@ +#ifndef _cuda_compat_cuh +#define _cuda_compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) +#if __CUDA_ARCH__ < 700 + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/server/custom_kernels/custom_kernels/exllama/cuda_func/column_remap.cu b/server/custom_kernels/custom_kernels/exllama/cuda_func/column_remap.cu new file mode 100644 index 000000000..65514c592 --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/cuda_func/column_remap.cu @@ -0,0 +1,59 @@ +#include "column_remap.cuh" +#include "../util.cuh" + +const int SHUF_BLOCKSIZE_X = 256; +const int SHUF_BLOCKSIZE_Y = 16; + +__global__ void column_remap_kernel +( + const half* __restrict__ x, + half* __restrict__ x_new, + const int x_width, + const int x_height, + const uint32_t* x_map +) +{ + int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; + + int x_stride = x_width; + int x_idx = x_row * x_stride + x_column; + + int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); + int x_idx_end = x_row_end * x_stride + x_column; + + int s_column = x_map[x_column]; + int s_idx = x_row * x_stride + s_column; + + while (x_idx < x_idx_end) + { + x_new[x_idx] = x[s_idx]; + x_idx += x_stride; + s_idx += x_stride; + } +} + +// Remap columns in x to correspond to sequential group index before matmul +// +// perform x -> seq_x such that seq_x @ seq_w == x @ w + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +) +{ + dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); + + dim3 blocks + ( + (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, + (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, + 1 + ); + + column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); +} diff --git a/server/custom_kernels/custom_kernels/exllama/cuda_func/column_remap.cuh 
b/server/custom_kernels/custom_kernels/exllama/cuda_func/column_remap.cuh new file mode 100644 index 000000000..f115db737 --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/cuda_func/column_remap.cuh @@ -0,0 +1,17 @@ +#ifndef _column_remap_cuh +#define _column_remap_cuh + +#include +#include +#include + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +); + +#endif \ No newline at end of file diff --git a/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matmul.cu b/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matmul.cu new file mode 100644 index 000000000..60dc4c9db --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matmul.cu @@ -0,0 +1,252 @@ +#include "q4_matmul.cuh" +#include "column_remap.cuh" +#include "../util.cuh" +#include "../matrix.cuh" +#include "../cuda_compat.cuh" +#include "../cuda_buffers.cuh" + +const int THREADS_X = 32; // Block size and thread count along columns in w and out +const int THREADS_Y = 1; // Block size and thread count along rows in x and out + +typedef void (*fp_q4_matmul_kernel) +( + const half*, + const uint32_t*, + half*, + const half*, + const uint32_t*, + const int, + const int, + const int, + const int, + const int, + const uint32_t*, + bool +); + +template +__global__ void q4_matmul_kernel +( + const half* __restrict__ x, + const uint32_t* __restrict__ w, + half* __restrict__ out, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int dim, + const int width, + const int groupsize, + const int block_size_z, + const uint32_t* __restrict__ x_map, + bool no_zero +) +{ + // Start of block + + int x_column = block_size_z * blockIdx.z; + int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); + + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + int x_row = THREADS_Y * blockIdx.y + threadIdx.y; + + int iterations = (x_column_end - x_column) / 8; + + // Views + + MatrixView_half x_(x, height, dim); + MatrixView_half w_scales_(w_scales, dim / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); + MatrixView_q4_column w_(w, dim, width); + MatrixView_half_rw out_(out, height, width); + + // Zero output + + if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) + { + *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; + __syncthreads(); + } + + // Loop over part of x row (and w column) + + half2 acc = {}; + half acc_h = {}; + + if constexpr (use_groupsize) + { + // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this + // could be slightly faster + + for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) + { + if constexpr (use_half2) + { + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + else + { + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 
8); + } + } + } + else + { + // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache + + for (int k = x_column; k < x_column + iterations * 8; k += 8) + { + if constexpr (use_half2) + { + int group = k / groupsize; + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + else + { + int group = k / groupsize; + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + } + } + + // Add to block result + + if constexpr (use_half2) + { + half result = __hadd(acc.x, acc.y); + atomicAdd(out_.item_ptr(x_row, w_column), result); + } + else + { + atomicAdd(out_.item_ptr(x_row, w_column), acc_h); + } +} + +fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) +{ + // + if (tuningParams->matmul_no_half2) { + if (block_size_z % groupsize == 0) { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } else { + if (block_size_z % groupsize == 0) + { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } +}; + +// Compute y = x @ w + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + + uint32_t* x_map = w->cuda_x_map; + const half* x_mapped = x; + if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) + { + CudaBuffers* buffers = get_buffers(w->device); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + x_map = NULL; + } + + int block_size_z; + if (w->width == 4096) block_size_z = 384; // 7B + else if (w->width == 11008) block_size_z = 256; + else if (w->width == 5120) block_size_z = 384; // 13B + else if (w->width == 13824) block_size_z = 256; + else if (w->width == 6656) block_size_z = 256; // 33B + else if (w->width == 17920) block_size_z = 128; + else block_size_z = 256; + + //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); + + dim3 threads(THREADS_X, THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height + threads.y - 1) / threads.y, + (dim + block_size_z - 1) / block_size_z + ); + + fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); + + kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); +} + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + const cublasHandle_t handle, + bool no_zero +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + 
CudaBuffers* buffers = get_buffers(w->device); + + const half* x_mapped = x; + if (w->cuda_x_map) + { + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + } + + w->reconstruct(buffers->temp_dq); + + const half alpha = __float2half(1.0f); + const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); + cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); + +// const float alpha = 1.0f; +// const float beta = no_zero ? 1.0f : 0.0f; +// cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, +// x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); +} diff --git a/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matmul.cuh b/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matmul.cuh new file mode 100644 index 000000000..2cf35e809 --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matmul.cuh @@ -0,0 +1,35 @@ +#ifndef _q4_matmul_cuh +#define _q4_matmul_cuh + +#include +#include +#include +#include +#include + +#include "q4_matrix.cuh" +#include "../tuning.h" + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero = false, + cudaStream_t alt_stream = NULL +); + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + const cublasHandle_t handle, + bool no_zero = false +); + +#endif diff --git a/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matrix.cu b/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matrix.cu new file mode 100644 index 000000000..d02fa0ece --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matrix.cu @@ -0,0 +1,215 @@ +#include "q4_matrix.cuh" +#include +#include "../util.cuh" +#include "../matrix.cuh" + +using namespace std; + +const int UNSHUF_BLOCKSIZE_X = 64; + +const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column +const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows + +vector g_q4_matrices; + +void g_q4_keep_matrix(Q4Matrix* m) +{ + g_q4_matrices.push_back(m); +} + +void g_q4_free_matrices() +{ + for (const auto& m : g_q4_matrices) delete m; + g_q4_matrices.clear(); +} + +Q4Matrix::Q4Matrix +( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device +) : + height(_height), + width(_width), + groups(_groups), + device(_device) +{ + cudaSetDevice(device); + + cuda_qweight = _qweight; + cuda_qzeros = _qzeros; + cuda_scales = _scales; + + groupsize = height / groups; + + if (_g_idx) make_sequential(_g_idx); +} + +Q4Matrix::~Q4Matrix() +{ +} + +// Make sequential + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint32_t* __restrict__ x_map, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int w_new2_row = blockIdx.y; + + int x_map_idx = w_new2_row << 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = 
x_map[x_map_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) +{ + uint32_t* cuda_new_qweight = NULL; + cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Move to CUDA + + cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); + dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1); + + make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); + + // Replace qweights + + cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); +} + +__global__ void reconstruct_kernel +( + const uint32_t* __restrict__ w, + half* __restrict__ out, // (y) + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int width, + const int groupsize +) +{ + // Start of block + + int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; + int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; + + // Views + + MatrixView_q4_column w_(w, height, width); + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, height / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); + + // Groupsize version + + int group = row / groupsize; + + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + + uint32_t w_read = w_.item_uint32_t(row, column); + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int s = 0; s < 32; s += 4) + { + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + *out_ptr = w_item; out_ptr += out_.width; + } +} + +void Q4Matrix::reconstruct(half* out) +{ + dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height / 8 + threads.y - 1) / threads.y, + 1 + ); + + reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); +} \ No newline at end of file diff --git 
a/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matrix.cuh b/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matrix.cuh new file mode 100644 index 000000000..e8135d989 --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/cuda_func/q4_matrix.cuh @@ -0,0 +1,51 @@ +#ifndef _q4_matrix_cuh +#define _q4_matrix_cuh + +#include +#include +#include + +class Q4Matrix +{ +public: + + int device; + + int height; + int width; + int groups; + int groupsize; + + uint32_t* cuda_qweight = NULL; + uint32_t* cuda_qzeros = NULL; + half* cuda_scales = NULL; + uint32_t* cuda_x_map = NULL; + + Q4Matrix + ( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device + ); + + ~Q4Matrix(); + + void reconstruct(half* out); + +private: + + void make_sequential(const uint32_t* cpu_g_idx); + +}; + +void g_q4_keep_matrix(Q4Matrix* m); +void g_q4_free_matrices(); + +#endif \ No newline at end of file diff --git a/server/custom_kernels/custom_kernels/exllama/exllama_ext.cpp b/server/custom_kernels/custom_kernels/exllama/exllama_ext.cpp new file mode 100644 index 000000000..49b24ba21 --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/exllama_ext.cpp @@ -0,0 +1,247 @@ +#include +#include +#include +#include +#include +#include +#include +#include "util.cuh" +#include "tuning.h" +#include "cuda_buffers.cuh" +#include "cuda_func/q4_matrix.cuh" +#include "cuda_func/q4_matmul.cuh" +#include "cuda_func/column_remap.cuh" + +// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a +// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of +// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. 
+ +void check_cuda(cudaError_t ret) +{ + switch (ret) + { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); \ + printf(" **** %s\n", cudaGetErrorString(ret)); \ + TORCH_CHECK(false, "CUDA error"); \ + break; + } +} + +// Some decluttering macros + +#define STRINGIFY_(__x) #__x +#define STRINGIFY(__x) STRINGIFY_(__x) +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ +do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ + TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ +} while(0) + +#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ +do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ +} while(0) + +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) +{ + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; +} + + +// Tuning parameters + +ExLlamaTuning tuningParams; + +void set_tuning_params +( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2 +) +{ + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; +} + + +// Release all unmanaged objects allocated by the extension + +void cleanup() +{ + cleanup_buffers_cuda(); + g_q4_free_matrices(); +} + + +// Prepare buffers for forward pass + +void prepare_buffers +( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq +) +{ + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + + prepare_buffers_cuda + ( + device_index, + (half*) temp_state.data_ptr(), + (half*) temp_dq.data_ptr() + ); +} + + +// Create Q4Matrix, return handle + +uintptr_t make_q4 +( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device +) +{ + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + + int width = qweight.size(1); + 
int height = qweight.size(0) * 8; + int groups = qzeros.size(0); + + Q4Matrix* m = new Q4Matrix + ( + height, + width, + groups, + + (uint32_t*) qweight.data_ptr(), + (uint32_t*) qzeros.data_ptr(), + (half*) scales.data_ptr(), + g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), + + device + ); + + g_q4_keep_matrix(m); + return reinterpret_cast (m); +} + + +// Matmul half @ quant -> half + +void q4_matmul +( + torch::Tensor x, + uintptr_t w, + torch::Tensor out +) +{ + Q4Matrix* wm = reinterpret_cast (w); + + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) + { + q4_matmul_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr() + ); + } + else + { + q4_matmul_recons_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + at::cuda::getCurrentCUDABlasHandle() + ); + } +} + + +// Remap columns in half tensor + +void column_remap +( + torch::Tensor x, + torch::Tensor x_new, + torch::Tensor x_map +) +{ + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda + ( + (half*) x.data_ptr(), + (half*) x_new.data_ptr(), + height, + width, + (uint32_t*) x_map.data_ptr() + ); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); +} diff --git a/server/custom_kernels/custom_kernels/exllama/matrix.cuh b/server/custom_kernels/custom_kernels/exllama/matrix.cuh new file mode 100644 index 000000000..4951a76fc --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/matrix.cuh @@ -0,0 +1,292 @@ +#ifndef _matrix_cuh +#define _matrix_cuh + +#include +#include + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + 
column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } +}; + +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +}; + +// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale + +__device__ __forceinline__ half2 dot_product_8 +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) 
+ const int count +) +{ + const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + +// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) +// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; +// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; +// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; + + half2 tmp = __hmul2(*h_ptr++, v_01); + tmp = __hfma2(*h_ptr++, v_23, tmp); + tmp = __hfma2(*h_ptr++, v_45, tmp); + tmp = __hfma2(*h_ptr++, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half* h_ptr = h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(*h_ptr++, v_0); + tmp = __hfma(*h_ptr++, v_1, tmp); + tmp = __hfma(*h_ptr++, v_2, tmp); + tmp = __hfma(*h_ptr++, v_3, tmp); + tmp = __hfma(*h_ptr++, v_4, tmp); + tmp = __hfma(*h_ptr++, v_5, tmp); + tmp = __hfma(*h_ptr++, v_6, tmp); + tmp = __hfma(*h_ptr++, v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map + +__device__ __forceinline__ half2 dot_product_8_x_map +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) 
+ const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + + half h_0 = h_ptr[*x_map_ptr++]; + half h_1 = h_ptr[*x_map_ptr++]; + half h_2 = h_ptr[*x_map_ptr++]; + half h_3 = h_ptr[*x_map_ptr++]; + half h_4 = h_ptr[*x_map_ptr++]; + half h_5 = h_ptr[*x_map_ptr++]; + half h_6 = h_ptr[*x_map_ptr++]; + half h_7 = h_ptr[*x_map_ptr++]; + + half2 h_01 = __halves2half2(h_0, h_1); + half2 h_23 = __halves2half2(h_2, h_3); + half2 h_45 = __halves2half2(h_4, h_5); + half2 h_67 = __halves2half2(h_6, h_7); + + half2 tmp = __hmul2(h_01, v_01); + tmp = __hfma2(h_23, v_23, tmp); + tmp = __hfma2(h_45, v_45, tmp); + tmp = __hfma2(h_67, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_x_map_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) 
+ const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); + tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +#endif diff --git a/server/custom_kernels/custom_kernels/exllama/tuning.h b/server/custom_kernels/custom_kernels/exllama/tuning.h new file mode 100644 index 000000000..8ca692eec --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/tuning.h @@ -0,0 +1,11 @@ +#ifndef _tuning_h +#define _tuning_h + +struct ExLlamaTuning +{ + int matmul_recons_thd; + bool matmul_fused_remap; + bool matmul_no_half2; +}; + +#endif diff --git a/server/custom_kernels/custom_kernels/exllama/util.cuh b/server/custom_kernels/custom_kernels/exllama/util.cuh new file mode 100644 index 000000000..dd1cfc7d0 --- /dev/null +++ b/server/custom_kernels/custom_kernels/exllama/util.cuh @@ -0,0 +1,27 @@ +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include + +#define cudaUnspecified cudaErrorApiFailureBase + +// React to failure on return code != cudaSuccess + +#define _cuda_check(fn) \ +do { \ + {_cuda_err = fn;} \ + if (_cuda_err != cudaSuccess) goto _cuda_fail; \ +} while(false) + +// React to failure on return code == 0 + +#define _alloc_check(fn) \ +do { \ + if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ + else _cuda_err = cudaSuccess; \ +} while(false) + +#endif diff --git a/server/custom_kernels/setup.py b/server/custom_kernels/setup.py index 43b8ee4ed..2af50d942 100644 --- a/server/custom_kernels/setup.py +++ b/server/custom_kernels/setup.py @@ -1,5 +1,5 @@ from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension setup( name="custom_kernels", @@ -14,6 +14,16 @@ setup( sources=["custom_kernels/fused_attention_cuda.cu"], extra_compile_args=["-arch=compute_80", "-std=c++17"], ), + CppExtension( + name="custom_kernels.exllama", + sources=[ + "custom_kernels/exllama/exllama_ext.cpp", + "custom_kernels/exllama/cuda_buffers.cu", + "custom_kernels/exllama/cuda_func/column_remap.cu", + "custom_kernels/exllama/cuda_func/q4_matmul.cu", + "custom_kernels/exllama/cuda_func/q4_matrix.cu" + ], + ) ], cmdclass={"build_ext": BuildExtension}, ) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index aeb1f13b7..4afa0183c 100644 --- 
a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -14,6 +14,7 @@ app = typer.Typer() class Quantization(str, Enum): bitsandbytes = "bitsandbytes" gptq = "gptq" + gptq_cuda = "gptq-cuda" @app.command() diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 2abde6854..deb9e5d3d 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -246,7 +246,7 @@ def get_model( if sharded: raise ValueError("sharded is not supported for AutoModel") - if quantize == "gptq": + if quantize in ["gptq", "gptq-cuda"]: raise ValueError( "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 4eb0034dd..0748d8ca6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -7,7 +7,6 @@ from typing import Optional # Flash attention imports import flash_attn_cuda - from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -17,15 +16,17 @@ from text_generation_server.utils.layers import ( get_linear, ) +from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear +from custom_kernels.exllama import prepare_buffers, set_tuning_params def load_multi_mqa( config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size ): - - if config.quantize == "gptq": - return _load_multi_mqa_gptq( + if config.quantize in ["gptq", "gptq-cuda"]: + layer = _load_multi_mqa_gptq( config, prefix, weights, bias, head_size, num_heads, hidden_size ) + return layer else: return _load_multi_mqa( config, prefix, weights, bias, head_size, num_heads, hidden_size @@ -87,7 +88,7 @@ def _load_multi_mqa_gptq( kv_tensor = slice_[-2 * head_size :] bias = torch.cat([q_tensor, kv_tensor], dim=0) - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device)) else: raise NotImplementedError("Gptq loading with santacoder is not implemented") @@ -95,7 +96,6 @@ def _load_multi_mqa_gptq( def _load_multi_mqa( config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size ): - if any("c_attn" in k for k in weights.routing.keys()): slice_ = weights._get_slice(f"{prefix}.c_attn.weight") shape = slice_.get_shape() @@ -160,7 +160,7 @@ def _load_multi_mqa( assert list(bias.shape) == [ (num_heads + 2) * head_size ], f"{weight.shape} != {[(num_heads + 2) * head_size]}" - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device)) def load_col(config, prefix: str, weights, bias: bool): @@ -175,22 +175,40 @@ def load_col(config, prefix: str, weights, bias: bool): bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device)) def load_row(config, prefix: str, weights, bias: bool): + quantize = config.quantize + if 
quantize == "gptq-cuda" and weights.process_group.size() > 1: + g_idx = weights.get_tensor(f"{prefix}.g_idx") + groupsize = weights.get_tensor("gptq_groupsize").item() + + act_order = True + if g_idx is not None: + if torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) or (g_idx == 0).all(): + act_order = False + else: + # Exllama implementation does not support row tensor parallelism with act-order, as + # it would require to reorder input activations that are split unto several GPUs + quantize = "gptq" + if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T else: - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_multi_weights_row(prefix, quantize=quantize) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process bias = weights.get_tensor(f"{prefix}.bias") else: bias = None + + if quantize == "gptq-cuda" and not act_order: + weight[3] = None # remove g_idx to indicate to exllama that act-order is not used + return TensorParallelRowLinear( - get_linear(weight, bias, config.quantize), process_group=weights.process_group + get_linear(weight, bias, quantize, device=weights.device), process_group=weights.process_group ) @@ -495,6 +513,30 @@ class FlashSantacoderForCausalLM(nn.Module): config, prefix="transformer.wte", weights=weights ) + # Buffers need to be persistent to avoid any bug. + self.buffers = {} + if config.quantize == "gptq-cuda": + max_dq_buffer_size = 0 + for name, submodule in self.named_modules(): + if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear): + max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8) + + intermediate_size = config.n_inner + max_seq_len = 2048 # TODO: we should be able to set it + + self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=weights.device) + self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=weights.device) + + prepare_buffers(weights.device, self.buffers["temp_state"], self.buffers["temp_dq"]) + + # TODO: ability to set them + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + torch.cuda.empty_cache() + def forward( self, input_ids, diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index a71c00613..5cbed5488 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -59,6 +59,7 @@ class FlashSantacoderSharded(FlashCausalLM): model = FlashSantacoderForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) + super(FlashCausalLM, self).__init__( model=model.to(device), tokenizer=tokenizer, diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py index 54fa20144..f1ba9bf9a 100644 --- a/server/text_generation_server/utils/gptq/quant_linear.py +++ b/server/text_generation_server/utils/gptq/quant_linear.py @@ -252,6 +252,7 @@ class QuantLinear(nn.Module): self.register_buffer("qzeros", qzeros) self.register_buffer("scales", scales) self.register_buffer("g_idx", g_idx) + if bias is not None: self.register_buffer("bias", bias) else: @@ -357,3 +358,82 @@ 
class QuantLinear(nn.Module): ) out = out + self.bias if self.bias is not None else out return out.reshape(out_shape) + +import torch +from custom_kernels.exllama import make_q4, q4_matmul + +# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension +none_tensor = torch.empty((1, 1), device = "meta") + +def ext_make_q4(qweight, qzeros, scales, g_idx, device): + """Construct Q4Matrix, return handle""" + return make_q4(qweight, + qzeros, + scales, + g_idx if g_idx is not None else none_tensor, + device) + +def ext_q4_matmul(x, q4, q4_width): + """Matrix multiplication, returns x @ q4""" + outshape = x.shape[:-1] + (q4_width,) + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device) + + q4_matmul(x, q4, output) + + return output.view(outshape) + + +class Ex4bitLinear: + """Linear layer implementation with per-group 4-bit quantization of the weights""" + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize, device, world_size: int): + assert bits == 4 + + self.device = device + self.qweight = qweight.to(device) + self.qzeros = qzeros.to(device) + self.scales = scales.to(device) + self.g_idx = g_idx.cpu() if g_idx is not None else None + self.bias = bias.to(device) if bias is not None else None + + if self.g_idx is not None and (self.g_idx == 0).all(): + self.empty_g_idx = True + self.g_idx = None + + assert device.type == "cuda" + assert device.index is not None + + self.q4 = ext_make_q4( + self.qweight, + self.qzeros, + self.scales, + self.g_idx, + device.index + ) + + self.height = qweight.shape[0] * 8 + self.width = qweight.shape[1] + + # Infer groupsize from height of qzeros + self.groupsize = None + if self.qzeros.shape[0] > 1: + if world_size is None: + world_size = 1 + # self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0] // world_size) + self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0]) + + assert groupsize == self.groupsize + + # Handle act-order matrix + if self.g_idx is not None: + if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?") + self.act_order = True + else: + self.act_order = False + + def forward(self, x): + out = ext_q4_matmul(x, self.q4, self.width) + + if self.bias is not None: + out.add_(self.bias) + return out diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index a2b0c7396..dd399ee3d 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -15,8 +15,9 @@ except ImportError: from accelerate import init_empty_weights -from text_generation_server.utils.gptq.quant_linear import QuantLinear +from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear +from typing import Optional # Monkey patching @classmethod @@ -118,7 +119,7 @@ class Linear8bitLt(nn.Module): return out -def get_linear(weight, bias, quantize): +def get_linear(weight, bias, quantize, device = None): if quantize is None: linear = FastLinear(weight, bias) elif quantize == "bitsandbytes": @@ -147,6 +148,15 @@ def get_linear(weight, bias, quantize): bits, groupsize, ) + elif quantize == "gptq-cuda": + try: + qweight, qzeros, scales, g_idx, bits, groupsize = weight + except Exception: + raise NotImplementedError( + f"The passed weight is not `gptq` compatible, loader needs to be updated." 
+ ) + + linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize, device, world_size) else: raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") return linear @@ -171,12 +181,12 @@ class TensorParallelHead(SuperLayer): weight = weights.get_sharded(f"{prefix}.weight", dim=0) # GPTQ doesn't quantize heads (nor embeddings) - if config.quantize == "gptq": + if config.quantize in ["gptq", "gptq-cuda"]: quantize = None else: quantize = config.quantize return TensorParallelHead( - get_linear(weight, bias=None, quantize=quantize), + get_linear(weight, bias=None, quantize=quantize, device=weights.device), process_group=weights.process_group, ) @@ -232,7 +242,7 @@ class TensorParallelColumnLinear(SuperLayer): bias = torch.cat(b, dim=dim) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias, config.quantize, device=weights.device) return cls(linear) @@ -251,7 +261,7 @@ class TensorParallelRowLinear(SuperLayer): else: bias = None return cls( - get_linear(weight, bias, config.quantize), + get_linear(weight, bias, config.quantize, device=weights.device), process_group=weights.process_group, ) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 9d3718349..1564bdee4 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -3,7 +3,6 @@ from typing import List, Dict, Optional from safetensors import safe_open import torch - class Weights: def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None): routing = {} @@ -43,7 +42,7 @@ class Weights: return str(filename), tensor_name def _get_slice(self, tensor_name: str): - filename, tensor_name= self.get_filename(tensor_name) + filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) return slice_ @@ -92,7 +91,7 @@ class Weights: return tensor def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): - if quantize == "gptq": + if quantize in ["gptq", "gptq-cuda"]: try: qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1) except RuntimeError: @@ -107,26 +106,39 @@ class Weights: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() - weight = (qweight, qzeros, scales, g_idx, bits, groupsize) + weight = [qweight, qzeros, scales, g_idx, bits, groupsize] else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] weight = torch.cat(w, dim=dim) return weight - def get_multi_weights_row(self, prefix: str, quantize: str): - if quantize == "gptq": + def get_multi_weights_row(self, prefix: str, quantize: str): + if quantize in ["gptq", "gptq-cuda"]: try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) except RuntimeError: raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") - qzeros = self.get_tensor(f"{prefix}.qzeros") - scales = self.get_tensor(f"{prefix}.scales") - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + + if quantize == "gptq": + qzeros = self.get_tensor(f"{prefix}.qzeros") + scales = self.get_tensor(f"{prefix}.scales") + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + else: + # Exllama reorders the weights in advance and the activations on the fly, thus + # the scales and 
zero-points do not need to be reordered + qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) + scales = self.get_sharded(f"{prefix}.scales", dim=0) + + # For tp > 1, at this point we know we do not use act-order + if self.process_group.size() == 1: + g_idx = self.get_tensor(f"{prefix}.g_idx") + else: + g_idx = None bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() - weight = (qweight, qzeros, scales, g_idx, bits, groupsize) + weight = [qweight, qzeros, scales, g_idx, bits, groupsize] else: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight
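
Usage note (reviewer sketch, not part of the patch): the snippet below shows how the pieces introduced here are intended to fit together from Python — `prepare_buffers` and `set_tuning_params` from the new `custom_kernels.exllama` extension are called once per device, and `Ex4bitLinear` (added to `quant_linear.py`) wraps the packed GPTQ tensors and dispatches to `q4_matmul` in its forward pass. All shapes, the dummy zero-valued tensors, and the buffer sizes are illustrative assumptions, not values used by the server; it assumes the extension has been built from the sources added in this diff.

    # Hedged sketch of the new exllama path; hypothetical shapes for a
    # hidden x hidden layer with 4-bit packing and groupsize 128.
    import torch
    from custom_kernels.exllama import prepare_buffers, set_tuning_params
    from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear

    device = torch.device("cuda", 0)
    hidden, groupsize = 4096, 128

    # Buffers must be allocated once per device before any q4_matmul call
    # (mirrors what FlashSantacoderForCausalLM.__init__ does in this diff).
    max_seq_len = 2048                                   # assumed, as in the model code
    max_dq_buffer_size = (hidden // 8) * hidden * 8      # qweight.numel() * 8
    temp_state = torch.zeros((max_seq_len, 4 * hidden), dtype=torch.float16, device=device)
    temp_dq = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
    prepare_buffers(device, temp_state, temp_dq)
    # Positional args: matmul_recons_thd, matmul_fused_remap, matmul_no_half2
    set_tuning_params(8, False, False)

    # Packed GPTQ tensors would normally come from Weights.get_multi_weights_row;
    # here they are zero-valued placeholders with the expected dtypes/shapes.
    qweight = torch.zeros((hidden // 8, hidden), dtype=torch.int32, device=device)
    qzeros = torch.zeros((hidden // groupsize, hidden // 8), dtype=torch.int32, device=device)
    scales = torch.ones((hidden // groupsize, hidden), dtype=torch.float16, device=device)

    # g_idx=None: no act-order in this sketch, so make_q4 skips make_sequential.
    linear = Ex4bitLinear(qweight, qzeros, scales, None, None,
                          bits=4, groupsize=groupsize, device=device, world_size=1)

    x = torch.randn((1, hidden), dtype=torch.float16, device=device)
    y = linear.forward(x)   # Ex4bitLinear is not an nn.Module; call forward() directly
    print(y.shape)          # torch.Size([1, 4096])

The sketch keeps the same call order as the server code: buffers and tuning parameters are set up before any `Ex4bitLinear` forward pass, since `q4_matmul_cuda` pulls its temporary state from the per-device buffers registered by `prepare_buffers`.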