mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
add exllama gptq kernel
This commit is contained in:
parent 70f485bf9f
commit ee7ba48b9a
Makefile (3 changes)

@@ -56,3 +56,6 @@ run-bloom:
 run-bloom-quantize:
 	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
 
+clean:
+	rm -rf target aml
@@ -20,6 +20,7 @@ mod env_runtime;
 enum Quantization {
     Bitsandbytes,
     Gptq,
+    Gptq_cuda,
 }
 
 impl std::fmt::Display for Quantization {
@@ -32,10 +33,14 @@ impl std::fmt::Display for Quantization {
             Quantization::Gptq => {
                 write!(f, "gptq")
             }
+            Quantization::Gptq_cuda => {
+                write!(f, "gptq-cuda")
+            }
         }
     }
 }
 
 
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
server/custom_kernels/custom_kernels/exllama/cuda_buffers.cu (new file, 69 lines)

@@ -0,0 +1,69 @@
#define _cuda_buffers_cu
#include "cuda_buffers.cuh"

CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;

CudaBuffers::CudaBuffers
(
    int _device,
    half* _temp_state,
    half* _temp_dq
) :
    device(_device),
    temp_state(_temp_state),
    temp_dq(_temp_dq)
{
    cudaSetDevice(_device);

    cudaStreamCreate(&alt_stream_1);
    cudaStreamCreate(&alt_stream_2);
    cudaStreamCreate(&alt_stream_3);
    cudaEventCreate(&alt_stream_1_done);
    cudaEventCreate(&alt_stream_2_done);
    cudaEventCreate(&alt_stream_3_done);
}

CudaBuffers::~CudaBuffers()
{
    cudaStreamDestroy(alt_stream_1);
    cudaStreamDestroy(alt_stream_2);
    cudaStreamDestroy(alt_stream_3);
    cudaEventDestroy(alt_stream_1_done);
    cudaEventDestroy(alt_stream_2_done);
    cudaEventDestroy(alt_stream_3_done);
}

CudaBuffers* get_buffers(const int device_index)
{
    return g_buffers[device_index];
}

void prepare_buffers_cuda
(
    int _device,
    half* _temp_state,
    half* _temp_dq
)
{
    CudaBuffers* buffers = new CudaBuffers
    (
        _device,
        _temp_state,
        _temp_dq
    );

    g_buffers[_device] = buffers;
}

void cleanup_buffers_cuda()
{
    for (int i = 0; i < CUDA_MAX_DEVICES; i++)
    {
        if (!g_buffers[i]) continue;
        delete g_buffers[i];
        g_buffers[i] = NULL;
    }
}
cuda_buffers.cuh (new file, 50 lines)

@@ -0,0 +1,50 @@
#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

const int CUDA_MAX_DEVICES = 16;

// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif

class CudaBuffers
{
public:
    int device;

    half* temp_state;   // [max_hidden_rows * intermediate_size]
    half* temp_dq;      // size of largest quant tensor * 8

    cudaStream_t alt_stream_1;
    cudaStream_t alt_stream_2;
    cudaStream_t alt_stream_3;
    cudaEvent_t alt_stream_1_done;
    cudaEvent_t alt_stream_2_done;
    cudaEvent_t alt_stream_3_done;

    CudaBuffers
    (
        int _device,
        half* _temp_state,
        half* _temp_dq
    );
    ~CudaBuffers();
};

CudaBuffers* get_buffers(const int device_index);

void prepare_buffers_cuda
(
    int _device,
    half* _temp_state,
    half* _temp_dq
);

void cleanup_buffers_cuda();

#endif
server/custom_kernels/custom_kernels/exllama/cuda_compat.cuh (new file, 56 lines)

@@ -0,0 +1,56 @@
#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh

// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    do
    {
        assumed = old;
        __half_raw hsum;
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmpres = __hadd(hsum, val);
        hsum = __half_raw(tmpres);
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
    }
    while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
    unsigned int* address_as_ui = (unsigned int*)address;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do
    {
        assumed = old;
        half2 old_val = *((half2*)&old);
        half2 new_val = __hadd2(old_val, val);
        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
    }
    while (assumed != old);
}

//

#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ < 700

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

#endif
cuda_func/column_remap.cu (new file, 59 lines)

@@ -0,0 +1,59 @@
#include "column_remap.cuh"
#include "../util.cuh"

const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;

__global__ void column_remap_kernel
(
    const half* __restrict__ x,
    half* __restrict__ x_new,
    const int x_width,
    const int x_height,
    const uint32_t* x_map
)
{
    int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
    int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;

    int x_stride = x_width;
    int x_idx = x_row * x_stride + x_column;

    int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
    int x_idx_end = x_row_end * x_stride + x_column;

    int s_column = x_map[x_column];
    int s_idx = x_row * x_stride + s_column;

    while (x_idx < x_idx_end)
    {
        x_new[x_idx] = x[s_idx];
        x_idx += x_stride;
        s_idx += x_stride;
    }
}

// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
)
{
    dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);

    dim3 blocks
    (
        (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
        (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
        1
    );

    column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}
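The remap contract stated in the comment above ("perform x -> seq_x such that seq_x @ seq_w == x @ w") can be sanity-checked outside CUDA. A minimal NumPy sketch, not part of the commit, with illustrative array names:

import numpy as np

rng = np.random.default_rng(0)
dim, width = 16, 8
x = rng.standard_normal((4, dim)).astype(np.float32)
w = rng.standard_normal((dim, width)).astype(np.float32)

# Stand-in for the g_idx-derived x_map used by the kernel: a permutation of the columns of x.
x_map = rng.permutation(dim)

seq_x = x[:, x_map]   # what column_remap_kernel computes: x_new[:, i] = x[:, x_map[i]]
seq_w = w[x_map, :]   # rows of w reordered the same way (make_sequential does this for qweight)

assert np.allclose(seq_x @ seq_w, x @ w, atol=1e-5)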
cuda_func/column_remap.cuh (new file, 17 lines)

@@ -0,0 +1,17 @@
#ifndef _column_remap_cuh
#define _column_remap_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
);

#endif
cuda_func/q4_matmul.cu (new file, 252 lines)

@@ -0,0 +1,252 @@
#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cuda_compat.cuh"
#include "../cuda_buffers.cuh"

const int THREADS_X = 32;       // Block size and thread count along columns in w and out
const int THREADS_Y = 1;        // Block size and thread count along rows in x and out

typedef void (*fp_q4_matmul_kernel)
(
    const half*,
    const uint32_t*,
    half*,
    const half*,
    const uint32_t*,
    const int,
    const int,
    const int,
    const int,
    const int,
    const uint32_t*,
    bool
);

template<bool use_half2, bool use_groupsize, bool use_x_map>
__global__ void q4_matmul_kernel
(
    const half* __restrict__ x,
    const uint32_t* __restrict__ w,
    half* __restrict__ out,
    const half* __restrict__ w_scales,
    const uint32_t* __restrict__ w_zeros,
    const int height,
    const int dim,
    const int width,
    const int groupsize,
    const int block_size_z,
    const uint32_t* __restrict__ x_map,
    bool no_zero
)
{
    // Start of block

    int x_column = block_size_z * blockIdx.z;
    int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));

    int w_column = THREADS_X * blockIdx.x + threadIdx.x;
    int x_row = THREADS_Y * blockIdx.y + threadIdx.y;

    int iterations = (x_column_end - x_column) / 8;

    // Views

    MatrixView_half x_(x, height, dim);
    MatrixView_half w_scales_(w_scales, dim / groupsize, width);
    MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
    MatrixView_q4_column w_(w, dim, width);
    MatrixView_half_rw out_(out, height, width);

    // Zero output

    if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
    {
        *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
        __syncthreads();
    }

    // Loop over part of x row (and w column)

    half2 acc = {};
    half acc_h = {};

    if constexpr (use_groupsize)
    {
        // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
        // could be slightly faster

        for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
        {
            if constexpr (use_half2)
            {
                half2 w_scale = w_scales_.item_half2half2(group, w_column);
                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
                else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
            }
            else
            {
                half w_scale = w_scales_.item(group, w_column);
                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
                else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
            }
        }
    }
    else
    {
        // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache

        for (int k = x_column; k < x_column + iterations * 8; k += 8)
        {
            if constexpr (use_half2)
            {
                int group = k / groupsize;
                half2 w_scale = w_scales_.item_half2half2(group, w_column);
                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
                else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
            }
            else
            {
                int group = k / groupsize;
                half w_scale = w_scales_.item(group, w_column);
                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
                else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
            }
        }
    }

    // Add to block result

    if constexpr (use_half2)
    {
        half result = __hadd(acc.x, acc.y);
        atomicAdd(out_.item_ptr(x_row, w_column), result);
    }
    else
    {
        atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
    }
}

fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
{
    // <bool use_half2, bool use_groupsize, bool use_x_map>
    if (tuningParams->matmul_no_half2) {
        if (block_size_z % groupsize == 0) {
            if (x_map) return q4_matmul_kernel<false, true,  true >;
            else       return q4_matmul_kernel<false, true,  false>;
        } else {
            if (x_map) return q4_matmul_kernel<false, false, true >;
            else       return q4_matmul_kernel<false, false, false>;
        }
    } else {
        if (block_size_z % groupsize == 0)
        {
            if (x_map) return q4_matmul_kernel<true, true,  true >;
            else       return q4_matmul_kernel<true, true,  false>;
        } else {
            if (x_map) return q4_matmul_kernel<true, false, true >;
            else       return q4_matmul_kernel<true, false, false>;
        }
    }
};

// Compute y = x @ w

void q4_matmul_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    const Q4Matrix* w,
    half* out,
    bool no_zero,
    cudaStream_t alt_stream
)
{
    int height = x_height;
    int dim = w->height;
    int width = w->width;

    cudaSetDevice(w->device);

    uint32_t* x_map = w->cuda_x_map;
    const half* x_mapped = x;
    if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
    {
        CudaBuffers* buffers = get_buffers(w->device);
        column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
        x_mapped = buffers->temp_state;
        x_map = NULL;
    }

    int block_size_z;
    if (w->width == 4096) block_size_z = 384;           // 7B
    else if (w->width == 11008) block_size_z = 256;
    else if (w->width == 5120) block_size_z = 384;      // 13B
    else if (w->width == 13824) block_size_z = 256;
    else if (w->width == 6656) block_size_z = 256;      // 33B
    else if (w->width == 17920) block_size_z = 128;
    else block_size_z = 256;

    //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));

    dim3 threads(THREADS_X, THREADS_Y, 1);

    dim3 blocks
    (
        (width + threads.x - 1) / threads.x,
        (height + threads.y - 1) / threads.y,
        (dim + block_size_z - 1) / block_size_z
    );

    fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);

    kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}

void q4_matmul_recons_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    Q4Matrix* w,
    half* out,
    const cublasHandle_t handle,
    bool no_zero
)
{
    int height = x_height;
    int dim = w->height;
    int width = w->width;

    cudaSetDevice(w->device);
    CudaBuffers* buffers = get_buffers(w->device);

    const half* x_mapped = x;
    if (w->cuda_x_map)
    {
        column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
        x_mapped = buffers->temp_state;
    }

    w->reconstruct(buffers->temp_dq);

    const half alpha = __float2half(1.0f);
    const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
    cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);

    // const float alpha = 1.0f;
    // const float beta = no_zero ? 1.0f : 0.0f;
    // cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
    //               x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
}
cuda_func/q4_matmul.cuh (new file, 35 lines)

@@ -0,0 +1,35 @@
#ifndef _q4_matmul_cuh
#define _q4_matmul_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>

#include "q4_matrix.cuh"
#include "../tuning.h"

void q4_matmul_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    const Q4Matrix* w,
    half* out,
    bool no_zero = false,
    cudaStream_t alt_stream = NULL
);

void q4_matmul_recons_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    Q4Matrix* w,
    half* out,
    const cublasHandle_t handle,
    bool no_zero = false
);

#endif
cuda_func/q4_matrix.cu (new file, 215 lines)

@@ -0,0 +1,215 @@
#include "q4_matrix.cuh"
#include <vector>
#include "../util.cuh"
#include "../matrix.cuh"

using namespace std;

const int UNSHUF_BLOCKSIZE_X = 64;

const int RECONS_THREADS_X = 64;      // Block size and thread count along columns in out, each thread converts 1 column
const int RECONS_THREADS_Y = 1;       // Block size and thread count along rows in x and out, each thread converts 8 rows

vector<Q4Matrix*> g_q4_matrices;

void g_q4_keep_matrix(Q4Matrix* m)
{
    g_q4_matrices.push_back(m);
}

void g_q4_free_matrices()
{
    for (const auto& m : g_q4_matrices) delete m;
    g_q4_matrices.clear();
}

Q4Matrix::Q4Matrix
(
    const int _height,
    const int _width,
    const int _groups,

    uint32_t* _qweight,
    uint32_t* _qzeros,
    half* _scales,
    uint32_t* _g_idx,

    const int _device
) :
    height(_height),
    width(_width),
    groups(_groups),
    device(_device)
{
    cudaSetDevice(device);

    cuda_qweight = _qweight;
    cuda_qzeros = _qzeros;
    cuda_scales = _scales;

    groupsize = height / groups;

    if (_g_idx) make_sequential(_g_idx);
}

Q4Matrix::~Q4Matrix()
{
}

// Make sequential

__global__ void make_sequential_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const uint32_t* __restrict__ x_map,
    const int w_height,
    const int w_width
)
{
    const uint64_t* w2 = (uint64_t*) w;
    uint64_t* w_new2 = (uint64_t*) w_new;
    int w2_stride = w_width >> 1;

    int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
    int w_new2_row = blockIdx.y;

    int x_map_idx = w_new2_row << 3;

    uint64_t dst = 0;

    #pragma unroll
    for (int i = 0; i < 8; i++)
    {
        int source_row = x_map[x_map_idx++];

        int w2_row = source_row >> 3;
        int w2_subrow = source_row & 0x07;
        int w2_row_shift = w2_subrow << 2;
        int wnew2_row_shift = i << 2;

        uint64_t src = w2[w2_row * w2_stride + w2_column];
        src >>= w2_row_shift;
        src &= 0x0000000f0000000f;
        src <<= wnew2_row_shift;
        dst |= src;
    }

    w_new2[w_new2_row * w2_stride + w2_column] = dst;
}

void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
{
    uint32_t* cuda_new_qweight = NULL;
    cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
    cudaMalloc(&cuda_x_map, height * sizeof(uint32_t));  // TODO: Should probably be allocated in PyTorch

    uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
    uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
    uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));

    // Group histogram

    for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;

    // Group map

    for (int i = 0, acc = 0; i < groups; i++)
    {
        short tmp = cpu_g_idx_map[i];
        cpu_g_idx_map[i] = acc;
        acc += tmp;
    }

    // X map (inverse)

    for (int row = 0; row < height; row++)
    {
        uint32_t target_group = cpu_g_idx[row];
        uint32_t target_row = cpu_g_idx_map[target_group];
        cpu_g_idx_map[target_group]++;
        cpu_x_map_inv[row] = target_row;
    }

    // X map

    for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;

    // Move to CUDA

    cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);

    // Rearrange rows in w

    dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
    dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);

    make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);

    // Replace qweights

    cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);

    // Cleanup

    cudaDeviceSynchronize();
    cudaFree(cuda_new_qweight);
    free(cpu_g_idx_map);
    free(cpu_x_map);
    free(cpu_x_map_inv);
}

__global__ void reconstruct_kernel
(
    const uint32_t* __restrict__ w,
    half* __restrict__ out,  // (y)
    const half* __restrict__ w_scales,
    const uint32_t* __restrict__ w_zeros,
    const int height,
    const int width,
    const int groupsize
)
{
    // Start of block

    int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
    int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;

    // Views

    MatrixView_q4_column w_(w, height, width);
    MatrixView_half_rw out_(out, height, width);
    MatrixView_half w_scales_(w_scales, height / groupsize, width);
    MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);

    // Groupsize version

    int group = row / groupsize;

    half w_scale = w_scales_.item(group, column);
    uint32_t w_zero = w_zeros_.item(group, column) + 1;

    uint32_t w_read = w_.item_uint32_t(row, column);
    half* out_ptr = out_.item_ptr(row, column);

    #pragma unroll
    for (int s = 0; s < 32; s += 4)
    {
        half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
        *out_ptr = w_item; out_ptr += out_.width;
    }
}

void Q4Matrix::reconstruct(half* out)
{
    dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);

    dim3 blocks
    (
        (width + threads.x - 1) / threads.x,
        (height / 8 + threads.y - 1) / threads.y,
        1
    );

    reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
}
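The 4-bit layout that reconstruct_kernel (and the dot-product helpers in matrix.cuh) decode can be written out in a few lines. The following NumPy sketch is for illustration only and is not code from the commit: each uint32 of qweight packs eight consecutive rows of one column, low nibble first, and a weight dequantizes as (nibble - (stored_zero + 1)) * scale, matching the "+ 1" applied to w_zero in the kernels.

import numpy as np

def dequant_column_word(word, scale, stored_zero):
    # Unpack one 32-bit word of qweight: 8 rows of a single column, 4 bits each, low nibble first.
    nibbles = np.array([(int(word) >> (4 * i)) & 0x0F for i in range(8)])
    return (nibbles - (stored_zero + 1)) * scale

# Pack eight example quantized values (0..15) into one word, then dequantize them.
qvals = [3, 8, 15, 0, 7, 7, 12, 1]
word = 0
for i, q in enumerate(qvals):
    word |= q << (4 * i)
print(dequant_column_word(word, 0.01, 7))   # approx. [-0.05, 0.0, 0.07, -0.08, -0.01, -0.01, 0.04, -0.07]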
cuda_func/q4_matrix.cuh (new file, 51 lines)

@@ -0,0 +1,51 @@
#ifndef _q4_matrix_cuh
#define _q4_matrix_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

class Q4Matrix
{
public:

    int device;

    int height;
    int width;
    int groups;
    int groupsize;

    uint32_t* cuda_qweight = NULL;
    uint32_t* cuda_qzeros = NULL;
    half* cuda_scales = NULL;
    uint32_t* cuda_x_map = NULL;

    Q4Matrix
    (
        const int _height,
        const int _width,
        const int _groups,

        uint32_t* _qweight,
        uint32_t* _qzeros,
        half* _scales,
        uint32_t* _g_idx,

        const int _device
    );

    ~Q4Matrix();

    void reconstruct(half* out);

private:

    void make_sequential(const uint32_t* cpu_g_idx);

};

void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();

#endif
server/custom_kernels/custom_kernels/exllama/exllama_ext.cpp (new file, 247 lines)

@@ -0,0 +1,247 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "cuda_func/q4_matrix.cuh"
#include "cuda_func/q4_matmul.cuh"
#include "cuda_func/column_remap.cuh"

// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.

void check_cuda(cudaError_t ret)
{
    switch (ret)
    {
        case cudaSuccess:
            break;

        case cudaUnspecified:
            printf(" **** Unspecified error\n");
            TORCH_CHECK(false, "CUDA error");
            break;

        default:
            printf(" **** CUDA error\n"); \
            printf(" **** %s\n", cudaGetErrorString(ret)); \
            TORCH_CHECK(false, "CUDA error"); \
            break;
    }
}

// Some decluttering macros

#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))

#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
    TORCH_CHECK(__index >= 0, "no device index"); \
    TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)

#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
    TORCH_CHECK_DTYPE(__w, kInt); \
    TORCH_CHECK_DTYPE(__w_scales, kHalf); \
    TORCH_CHECK_DTYPE(__w_zeros, kInt); \
    TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
    TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
    TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
    TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)

int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
    int groupsize = w.size(0) * 8 / w_zeros.size(0);
    TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
    return groupsize;
}


// Tuning parameters

ExLlamaTuning tuningParams;

void set_tuning_params
(
    int matmul_recons_thd,
    bool matmul_fused_remap,
    bool matmul_no_half2
)
{
    tuningParams.matmul_recons_thd = matmul_recons_thd;
    tuningParams.matmul_fused_remap = matmul_fused_remap;
    tuningParams.matmul_no_half2 = matmul_no_half2;
}


// Release all unmanaged objects allocated by the extension

void cleanup()
{
    cleanup_buffers_cuda();
    g_q4_free_matrices();
}


// Prepare buffers for forward pass

void prepare_buffers
(
    torch::Device device,
    torch::Tensor temp_state,
    torch::Tensor temp_dq
)
{
    int device_index = device.index();
    TORCH_CHECK_DEVICE_INDEX(device_index);
    const at::cuda::OptionalCUDAGuard device_guard(device);

    prepare_buffers_cuda
    (
        device_index,
        (half*) temp_state.data_ptr(),
        (half*) temp_dq.data_ptr()
    );
}


// Create Q4Matrix, return handle

uintptr_t make_q4
(
    torch::Tensor qweight,
    torch::Tensor qzeros,
    torch::Tensor scales,
    torch::Tensor g_idx,
    int device
)
{
    TORCH_CHECK_DTYPE(qweight, kInt);
    TORCH_CHECK_DTYPE(qzeros, kInt);
    TORCH_CHECK_DTYPE(scales, kHalf);
    TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
    TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
    TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
    TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);

    int width = qweight.size(1);
    int height = qweight.size(0) * 8;
    int groups = qzeros.size(0);

    Q4Matrix* m = new Q4Matrix
    (
        height,
        width,
        groups,

        (uint32_t*) qweight.data_ptr(),
        (uint32_t*) qzeros.data_ptr(),
        (half*) scales.data_ptr(),
        g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),

        device
    );

    g_q4_keep_matrix(m);
    return reinterpret_cast<uintptr_t> (m);
}


// Matmul half @ quant -> half

void q4_matmul
(
    torch::Tensor x,
    uintptr_t w,
    torch::Tensor out
)
{
    Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);

    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(out, kHalf);
    TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
    TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    int x_height = x.size(0);

    if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
    {
        q4_matmul_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr()
        );
    }
    else
    {
        q4_matmul_recons_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr(),
            at::cuda::getCurrentCUDABlasHandle()
        );
    }
}


// Remap columns in half tensor

void column_remap
(
    torch::Tensor x,
    torch::Tensor x_new,
    torch::Tensor x_map
)
{
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(x_new, kHalf);
    TORCH_CHECK_DTYPE(x_map, kInt);
    TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);

    int height = x.size(0);
    int width = x.size(1);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    column_remap_cuda
    (
        (half*) x.data_ptr(),
        (half*) x_new.data_ptr(),
        height,
        width,
        (uint32_t*) x_map.data_ptr()
    );
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
    m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
    m.def("cleanup", &cleanup, "cleanup");
    m.def("make_q4", &make_q4, "make_q4");
    m.def("q4_matmul", &q4_matmul, "q4_matmul");
}
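The extension's Python surface is small. Below is a usage sketch, not part of the commit, showing the intended call order for the functions bound above; tensor shapes follow the checks in make_q4 and q4_matmul, while the sizes, tuning values and zero-filled tensors are illustrative placeholders.

import torch
from custom_kernels import exllama  # built by server/custom_kernels/setup.py

device = torch.device("cuda:0")
height, width, groupsize = 4096, 4096, 128           # illustrative sizes

# Persistent scratch buffers, registered once per device.
temp_state = torch.zeros((2048, width), dtype=torch.float16, device=device)
temp_dq = torch.zeros((height * width,), dtype=torch.float16, device=device)
exllama.prepare_buffers(device, temp_state, temp_dq)
exllama.set_tuning_params(8, False, False)            # matmul_recons_thd, fused_remap, no_half2

# GPTQ tensors as produced by the quantizer: int32 qweight/qzeros, fp16 scales.
qweight = torch.zeros((height // 8, width), dtype=torch.int32, device=device)
qzeros = torch.zeros((height // groupsize, width // 8), dtype=torch.int32, device=device)
scales = torch.zeros((height // groupsize, width), dtype=torch.float16, device=device)
g_idx = torch.empty((0,), device="meta")              # meta tensor = no act-order reordering

q4 = exllama.make_q4(qweight, qzeros, scales, g_idx, device.index)

x = torch.randn((4, height), dtype=torch.float16, device=device)
out = torch.empty((4, width), dtype=torch.float16, device=device)
exllama.q4_matmul(x, q4, out)                         # out = x @ dequant(qweight)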
server/custom_kernels/custom_kernels/exllama/matrix.cuh (new file, 292 lines)

@@ -0,0 +1,292 @@
#ifndef _matrix_cuh
#define _matrix_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>

class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};

class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};

class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
    }
};

class MatrixView_q4_column
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};

// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu

// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale

__device__ __forceinline__ half2 dot_product_8
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,              // + 1 (!!)
    const int count
)
{
    const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half2 v_01 = __halves2half2(v_0, v_1);
        half2 v_23 = __halves2half2(v_2, v_3);
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

        // half2 v_01 = q4_table[v_zero - 1][(v_read      ) & 0xff];  // (constant memory is too slow apparently)
        // half2 v_23 = q4_table[v_zero - 1][(v_read >>  8) & 0xff];
        // half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
        // half2 v_67 = q4_table[v_zero - 1][(v_read >> 24)       ];

        half2 tmp = __hmul2(*h_ptr++, v_01);
        tmp = __hfma2(*h_ptr++, v_23, tmp);
        tmp = __hfma2(*h_ptr++, v_45, tmp);
        tmp = __hfma2(*h_ptr++, v_67, tmp);
        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}

__device__ __forceinline__ half dot_product_8_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,              // + 1 (!!)
    const int count
)
{
    const half* h_ptr = h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half tmp = __hmul(*h_ptr++, v_0);
        tmp = __hfma(*h_ptr++, v_1, tmp);
        tmp = __hfma(*h_ptr++, v_2, tmp);
        tmp = __hfma(*h_ptr++, v_3, tmp);
        tmp = __hfma(*h_ptr++, v_4, tmp);
        tmp = __hfma(*h_ptr++, v_5, tmp);
        tmp = __hfma(*h_ptr++, v_6, tmp);
        tmp = __hfma(*h_ptr++, v_7, tmp);
        result = __hfma(v_scale, tmp, result);
    }

    return result;
}

// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map

__device__ __forceinline__ half2 dot_product_8_x_map
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,              // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    const half* h_ptr = h_.item_ptr(h_row, 0);
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half2 v_01 = __halves2half2(v_0, v_1);
        half2 v_23 = __halves2half2(v_2, v_3);
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

        half h_0 = h_ptr[*x_map_ptr++];
        half h_1 = h_ptr[*x_map_ptr++];
        half h_2 = h_ptr[*x_map_ptr++];
        half h_3 = h_ptr[*x_map_ptr++];
        half h_4 = h_ptr[*x_map_ptr++];
        half h_5 = h_ptr[*x_map_ptr++];
        half h_6 = h_ptr[*x_map_ptr++];
        half h_7 = h_ptr[*x_map_ptr++];

        half2 h_01 = __halves2half2(h_0, h_1);
        half2 h_23 = __halves2half2(h_2, h_3);
        half2 h_45 = __halves2half2(h_4, h_5);
        half2 h_67 = __halves2half2(h_6, h_7);

        half2 tmp = __hmul2(h_01, v_01);
        tmp = __hfma2(h_23, v_23, tmp);
        tmp = __hfma2(h_45, v_45, tmp);
        tmp = __hfma2(h_67, v_67, tmp);
        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}

__device__ __forceinline__ half dot_product_8_x_map_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,              // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    const half* h_ptr = h_.item_ptr(h_row, 0);
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
        result = __hfma(v_scale, tmp, result);
    }

    return result;
}

#endif
server/custom_kernels/custom_kernels/exllama/tuning.h (new file, 11 lines)

@@ -0,0 +1,11 @@
#ifndef _tuning_h
#define _tuning_h

struct ExLlamaTuning
{
    int matmul_recons_thd;
    bool matmul_fused_remap;
    bool matmul_no_half2;
};

#endif
server/custom_kernels/custom_kernels/exllama/util.cuh (new file, 27 lines)

@@ -0,0 +1,27 @@
#ifndef _util_cuh
#define _util_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#define cudaUnspecified cudaErrorApiFailureBase

// React to failure on return code != cudaSuccess

#define _cuda_check(fn) \
do { \
    {_cuda_err = fn;} \
    if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)

// React to failure on return code == 0

#define _alloc_check(fn) \
do { \
    if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
    else _cuda_err = cudaSuccess; \
} while(false)

#endif
@@ -1,5 +1,5 @@
 from setuptools import setup
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension

 setup(
     name="custom_kernels",
@@ -14,6 +14,16 @@ setup(
             sources=["custom_kernels/fused_attention_cuda.cu"],
             extra_compile_args=["-arch=compute_80", "-std=c++17"],
         ),
+        CppExtension(
+            name="custom_kernels.exllama",
+            sources=[
+                "custom_kernels/exllama/exllama_ext.cpp",
+                "custom_kernels/exllama/cuda_buffers.cu",
+                "custom_kernels/exllama/cuda_func/column_remap.cu",
+                "custom_kernels/exllama/cuda_func/q4_matmul.cu",
+                "custom_kernels/exllama/cuda_func/q4_matrix.cu"
+            ],
+        )
     ],
     cmdclass={"build_ext": BuildExtension},
 )
@@ -14,6 +14,7 @@ app = typer.Typer()
 class Quantization(str, Enum):
     bitsandbytes = "bitsandbytes"
     gptq = "gptq"
+    gptq_cuda = "gptq-cuda"


 @app.command()
@@ -246,7 +246,7 @@ def get_model(

     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
-    if quantize == "gptq":
+    if quantize in ["gptq", "gptq-cuda"]:
         raise ValueError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
@ -7,7 +7,6 @@ from typing import Optional
|
|||||||
|
|
||||||
# Flash attention imports
|
# Flash attention imports
|
||||||
import flash_attn_cuda
|
import flash_attn_cuda
|
||||||
|
|
||||||
from text_generation_server.utils.layers import (
|
from text_generation_server.utils.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
@ -17,15 +16,17 @@ from text_generation_server.utils.layers import (
|
|||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
|
||||||
|
from custom_kernels.exllama import prepare_buffers, set_tuning_params
|
||||||
|
|
||||||
def load_multi_mqa(
|
def load_multi_mqa(
|
||||||
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
||||||
):
|
):
|
||||||
|
if config.quantize in ["gptq", "gptq-cuda"]:
|
||||||
if config.quantize == "gptq":
|
layer = _load_multi_mqa_gptq(
|
||||||
return _load_multi_mqa_gptq(
|
|
||||||
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
||||||
)
|
)
|
||||||
|
return layer
|
||||||
else:
|
else:
|
||||||
return _load_multi_mqa(
|
return _load_multi_mqa(
|
||||||
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
||||||
@ -87,7 +88,7 @@ def _load_multi_mqa_gptq(
|
|||||||
kv_tensor = slice_[-2 * head_size :]
|
kv_tensor = slice_[-2 * head_size :]
|
||||||
bias = torch.cat([q_tensor, kv_tensor], dim=0)
|
bias = torch.cat([q_tensor, kv_tensor], dim=0)
|
||||||
|
|
||||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Gptq loading with santacoder is not implemented")
|
raise NotImplementedError("Gptq loading with santacoder is not implemented")
|
||||||
|
|
||||||
@ -95,7 +96,6 @@ def _load_multi_mqa_gptq(
|
|||||||
def _load_multi_mqa(
|
def _load_multi_mqa(
|
||||||
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
||||||
):
|
):
|
||||||
|
|
||||||
if any("c_attn" in k for k in weights.routing.keys()):
|
if any("c_attn" in k for k in weights.routing.keys()):
|
||||||
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
|
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
|
||||||
shape = slice_.get_shape()
|
shape = slice_.get_shape()
|
||||||
@@ -160,7 +160,7 @@ def _load_multi_mqa(
         assert list(bias.shape) == [
             (num_heads + 2) * head_size
         ], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))


 def load_col(config, prefix: str, weights, bias: bool):
@@ -175,22 +175,40 @@ def load_col(config, prefix: str, weights, bias: bool):
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
     else:
         bias = None
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))


 def load_row(config, prefix: str, weights, bias: bool):
+    quantize = config.quantize
+    if quantize == "gptq-cuda" and weights.process_group.size() > 1:
+        g_idx = weights.get_tensor(f"{prefix}.g_idx")
+        groupsize = weights.get_tensor("gptq_groupsize").item()
+
+        act_order = True
+        if g_idx is not None:
+            if torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) or (g_idx == 0).all():
+                act_order = False
+            else:
+                # Exllama implementation does not support row tensor parallelism with act-order, as
+                # it would require to reorder input activations that are split onto several GPUs
+                quantize = "gptq"
+
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
     else:
-        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
+        weight = weights.get_multi_weights_row(prefix, quantize=quantize)

     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
+
+    if quantize == "gptq-cuda" and not act_order:
+        weight[3] = None  # remove g_idx to indicate to exllama that act-order is not used
+
     return TensorParallelRowLinear(
-        get_linear(weight, bias, config.quantize), process_group=weights.process_group
+        get_linear(weight, bias, quantize, device=weights.device), process_group=weights.process_group
     )
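To make the act-order check above concrete, here is a minimal standalone sketch (illustration only, with made-up sizes; the real `g_idx` comes from the quantized checkpoint): a `g_idx` that simply maps input row `i` to group `i // groupsize` encodes no reordering, so the exllama fast path stays valid even when the rows are split across ranks, while a permuted `g_idx` forces the fallback to the plain "gptq" path.

import torch

# Hypothetical sizes for illustration only.
groupsize = 128
in_features = 512

# A trivial g_idx: row i belongs to group i // groupsize (no act-order reordering).
trivial_g_idx = torch.tensor(
    [i // groupsize for i in range(in_features)], dtype=torch.int32
)

# A reordered g_idx, as produced by GPTQ with act-order enabled.
permuted_g_idx = trivial_g_idx[torch.randperm(in_features)]

def uses_act_order(g_idx: torch.Tensor, groupsize: int) -> bool:
    """Return True when g_idx encodes a non-trivial row reordering."""
    expected = torch.tensor(
        [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
    )
    return not (torch.equal(g_idx.cpu(), expected) or bool((g_idx == 0).all()))

print(uses_act_order(trivial_g_idx, groupsize))   # False -> exllama kernel usable with row TP
print(uses_act_order(permuted_g_idx, groupsize))  # True  -> fall back to the "gptq" path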
@@ -495,6 +513,30 @@ class FlashSantacoderForCausalLM(nn.Module):
             config, prefix="transformer.wte", weights=weights
         )

+        # Buffers need to be persistent to avoid any bug.
+        self.buffers = {}
+        if config.quantize == "gptq-cuda":
+            max_dq_buffer_size = 0
+            for name, submodule in self.named_modules():
+                if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
+                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
+
+            intermediate_size = config.n_inner
+            max_seq_len = 2048  # TODO: we should be able to set it
+
+            self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=weights.device)
+            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=weights.device)
+
+            prepare_buffers(weights.device, self.buffers["temp_state"], self.buffers["temp_dq"])
+
+            # TODO: ability to set them
+            matmul_recons_thd = 8
+            matmul_fused_remap = False
+            matmul_no_half2 = False
+            set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+            torch.cuda.empty_cache()
+
     def forward(
         self,
         input_ids,
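As a rough, CPU-only illustration of the buffer sizing above (toy dimensions, not from a real checkpoint): each int32 element of a packed `qweight` holds eight 4-bit values, so a fully dequantized fp16 copy of the largest layer needs `qweight.numel() * 8` elements, which is the capacity `temp_dq` must provide. In the model these buffers live on the GPU and are registered with `prepare_buffers`.

import torch

# Assumed toy dimensions for illustration.
hidden_size, intermediate_size = 2048, 8192
max_seq_len = 2048

# Packed 4-bit GPTQ weight: 8 int4 values per int32 element along the input dimension.
qweight = torch.zeros((hidden_size // 8, intermediate_size), dtype=torch.int32)

# Largest dequantization scratch the exllama kernel would need for this layer.
max_dq_buffer_size = qweight.numel() * 8

temp_state = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16)
temp_dq = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16)

print(temp_state.shape, temp_dq.shape)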
@@ -59,6 +59,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         model = FlashSantacoderForCausalLM(config, weights)

         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
@@ -252,6 +252,7 @@ class QuantLinear(nn.Module):
         self.register_buffer("qzeros", qzeros)
         self.register_buffer("scales", scales)
         self.register_buffer("g_idx", g_idx)
         if bias is not None:
             self.register_buffer("bias", bias)
         else:
@@ -357,3 +358,82 @@ class QuantLinear(nn.Module):
         )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
+
+
+import torch
+from custom_kernels.exllama import make_q4, q4_matmul
+
+# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
+none_tensor = torch.empty((1, 1), device = "meta")
+
+def ext_make_q4(qweight, qzeros, scales, g_idx, device):
+    """Construct Q4Matrix, return handle"""
+    return make_q4(qweight,
+                   qzeros,
+                   scales,
+                   g_idx if g_idx is not None else none_tensor,
+                   device)
+
+def ext_q4_matmul(x, q4, q4_width):
+    """Matrix multiplication, returns x @ q4"""
+    outshape = x.shape[:-1] + (q4_width,)
+    x = x.view(-1, x.shape[-1])
+    output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
+
+    q4_matmul(x, q4, output)
+
+    return output.view(outshape)
+
+
+class Ex4bitLinear:
+    """Linear layer implementation with per-group 4-bit quantization of the weights"""
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize, device, world_size: int):
+        assert bits == 4
+
+        self.device = device
+        self.qweight = qweight.to(device)
+        self.qzeros = qzeros.to(device)
+        self.scales = scales.to(device)
+        self.g_idx = g_idx.cpu() if g_idx is not None else None
+        self.bias = bias.to(device) if bias is not None else None
+
+        if self.g_idx is not None and (self.g_idx == 0).all():
+            self.empty_g_idx = True
+            self.g_idx = None
+
+        assert device.type == "cuda"
+        assert device.index is not None
+
+        self.q4 = ext_make_q4(
+            self.qweight,
+            self.qzeros,
+            self.scales,
+            self.g_idx,
+            device.index
+        )
+
+        self.height = qweight.shape[0] * 8
+        self.width = qweight.shape[1]
+
+        # Infer groupsize from height of qzeros
+        self.groupsize = None
+        if self.qzeros.shape[0] > 1:
+            if world_size is None:
+                world_size = 1
+            # self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0] // world_size)
+            self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
+
+        assert groupsize == self.groupsize
+
+        # Handle act-order matrix
+        if self.g_idx is not None:
+            if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
+            self.act_order = True
+        else:
+            self.act_order = False
+
+    def forward(self, x):
+        out = ext_q4_matmul(x, self.q4, self.width)
+
+        if self.bias is not None:
+            out.add_(self.bias)
+        return out
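For intuition about what `q4_matmul` computes, the following is a simplified pure-PyTorch reference of a per-group 4-bit linear (my own sketch with made-up shapes; it follows the common GPTQ packing of eight 4-bit weights per int32 along the input dimension, keeps the zero-points unpacked, and ignores act-order, so it mirrors the math rather than the exllama kernel's data layout).

import torch

# Illustration with made-up shapes; float32 here for a CPU-friendly sketch, the kernel works in float16.
in_features, out_features, groupsize = 64, 32, 16
n_groups = in_features // groupsize

# Each int32 stores 8 consecutive input rows for one output column.
qweight = torch.randint(0, 2**31 - 1, (in_features // 8, out_features), dtype=torch.int32)
scales = torch.rand(n_groups, out_features)
zeros = torch.full((n_groups, out_features), 8.0)  # unpacked zero-points for simplicity

def dequantize_4bit(qweight, scales, zeros, groupsize):
    """Unpack 4-bit values and apply per-group scale/zero-point: w = scale * (q - zero)."""
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)            # 8 nibbles per int32
    q = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF     # (rows/8, 8, cols)
    q = q.reshape(qweight.shape[0] * 8, qweight.shape[1]).to(torch.float32)
    g = torch.arange(q.shape[0]) // groupsize                     # group index per input row
    return scales[g] * (q - zeros[g])

x = torch.rand(4, in_features)
w = dequantize_4bit(qweight, scales, zeros, groupsize)            # (in_features, out_features)
out = x @ w                                                       # what ext_q4_matmul computes on GPU
print(out.shape)  # torch.Size([4, 32])

On the GPU path the matrix stays packed on the device: `ext_make_q4` builds it once, and `q4_matmul` performs the multiply in half precision (reconstructing to fp16 only above the `matmul_recons_thd` threshold).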
@@ -15,8 +15,9 @@ except ImportError:

 from accelerate import init_empty_weights

-from text_generation_server.utils.gptq.quant_linear import QuantLinear
+from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
+from typing import Optional

 # Monkey patching
 @classmethod
@@ -118,7 +119,7 @@ class Linear8bitLt(nn.Module):
         return out


-def get_linear(weight, bias, quantize):
+def get_linear(weight, bias, quantize, device = None):
     if quantize is None:
         linear = FastLinear(weight, bias)
     elif quantize == "bitsandbytes":
@@ -147,6 +148,15 @@ def get_linear(weight, bias, quantize):
             bits,
             groupsize,
         )
+    elif quantize == "gptq-cuda":
+        try:
+            qweight, qzeros, scales, g_idx, bits, groupsize = weight
+        except Exception:
+            raise NotImplementedError(
+                f"The passed weight is not `gptq` compatible, loader needs to be updated."
+            )
+
+        linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize, device, world_size)
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
@@ -171,12 +181,12 @@ class TensorParallelHead(SuperLayer):
         weight = weights.get_sharded(f"{prefix}.weight", dim=0)

         # GPTQ doesn't quantize heads (nor embeddings)
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "gptq-cuda"]:
             quantize = None
         else:
             quantize = config.quantize
         return TensorParallelHead(
-            get_linear(weight, bias=None, quantize=quantize),
+            get_linear(weight, bias=None, quantize=quantize, device=weights.device),
             process_group=weights.process_group,
         )
@@ -232,7 +242,7 @@ class TensorParallelColumnLinear(SuperLayer):
             bias = torch.cat(b, dim=dim)
         else:
             bias = None
-        linear = get_linear(weight, bias, config.quantize)
+        linear = get_linear(weight, bias, config.quantize, device=weights.device)
         return cls(linear)

@@ -251,7 +261,7 @@ class TensorParallelRowLinear(SuperLayer):
         else:
             bias = None
         return cls(
-            get_linear(weight, bias, config.quantize),
+            get_linear(weight, bias, config.quantize, device=weights.device),
             process_group=weights.process_group,
         )

@@ -3,7 +3,6 @@ from typing import List, Dict, Optional
 from safetensors import safe_open
 import torch


 class Weights:
     def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
         routing = {}
@@ -43,7 +42,7 @@ class Weights:
         return str(filename), tensor_name

     def _get_slice(self, tensor_name: str):
-        filename, tensor_name= self.get_filename(tensor_name)
+        filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
         return slice_
@@ -92,7 +91,7 @@ class Weights:
         return tensor

     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
-        if quantize == "gptq":
+        if quantize in ["gptq", "gptq-cuda"]:
             try:
                 qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
             except RuntimeError:
@@ -107,26 +106,39 @@ class Weights:

             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+            weight = [qweight, qzeros, scales, g_idx, bits, groupsize]
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
         return weight

     def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize == "gptq":
+        if quantize in ["gptq", "gptq-cuda"]:
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
-            qzeros = self.get_tensor(f"{prefix}.qzeros")
-            scales = self.get_tensor(f"{prefix}.scales")
-            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+
+            if quantize == "gptq":
+                qzeros = self.get_tensor(f"{prefix}.qzeros")
+                scales = self.get_tensor(f"{prefix}.scales")
+                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+            else:
+                # Exllama reorders the weights in advance and the activations on the fly, thus
+                # the scales and zero-points do not need to be reordered
+                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
+                scales = self.get_sharded(f"{prefix}.scales", dim=0)
+
+                # For tp > 1, at this point we know we do not use act-order
+                if self.process_group.size() == 1:
+                    g_idx = self.get_tensor(f"{prefix}.g_idx")
+                else:
+                    g_idx = None
+
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+            weight = [qweight, qzeros, scales, g_idx, bits, groupsize]
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
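A small sketch of why the exllama path can shard `qzeros` and `scales` row-wise (made-up shapes and ranks; not the actual `Weights` API): with row tensor parallelism each rank owns a contiguous slice of the input dimension, and since exllama pre-reorders the weights, the zero-points and scales for that slice are exactly the corresponding contiguous group rows, whereas a non-trivial `g_idx` would scatter a rank's rows across arbitrary groups.

import torch

# Assumed toy setup: 2-way row tensor parallelism over the input dimension.
world_size, in_features, out_features, groupsize = 2, 64, 32, 16
n_groups = in_features // groupsize

qzeros = torch.arange(n_groups * out_features, dtype=torch.int32).reshape(n_groups, out_features)
scales = torch.rand(n_groups, out_features)

def shard_rows(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Contiguous row slice owned by `rank`, mirroring a get_sharded(..., dim=0) call."""
    block = t.shape[0] // world_size
    return t[rank * block : (rank + 1) * block]

for rank in range(world_size):
    # Rank owns input rows [rank * 32, (rank + 1) * 32), i.e. groups [rank * 2, rank * 2 + 2).
    local_zeros = shard_rows(qzeros, rank, world_size)
    local_scales = shard_rows(scales, rank, world_size)
    print(rank, local_zeros.shape, local_scales.shape)  # (2, 32) each: the groups for this rank's rows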