Fixing exllamav2.

This commit is contained in:
Ubuntu 2023-11-23 11:24:41 +00:00 committed by Nicolas Patry
parent fb64ce1040
commit a61f432599
10 changed files with 154 additions and 51 deletions

View File

@ -0,0 +1,38 @@
#ifndef _compat_gemm_cuh
#define _compat_gemm_cuh
#if defined(USE_ROCM)
// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required
// for symbols as hipblasHalf.
#include <hipblas/hipblas.h>
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif
#endif

View File

@ -18,34 +18,7 @@
#include "q_gemm_kernel.cuh"
#include "q_gemm_kernel_gptq.cuh"
#if defined(USE_ROCM)
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif
#include "compat_gemm.cuh"
void gemm_half_q_half_cuda_part
(

View File

@ -1,5 +1,8 @@
#include "compat.cuh"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};

View File

@ -72,6 +72,8 @@ QMatrix::QMatrix
{
cudaSetDevice(device);
failed = false;
cuda_q_weight = _q_weight;
cuda_q_perm = _q_perm;
cuda_q_invperm = _q_invperm;
@ -125,7 +127,15 @@ QMatrix::QMatrix
rows_3 = height;
rows_2 = height;
if (_gptq_g_idx) make_sequential(_gptq_g_idx);
if (_gptq_g_idx)
{
if (!make_sequential(_gptq_g_idx))
{
failed = true;
//printf("FAIL\n");
return;
}
}
}
// Shuffle quantized data
@ -139,6 +149,9 @@ QMatrix::QMatrix
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}
QMatrix::~QMatrix()
{
}
// Reconstruct b[k,n] (GPTQ)
@ -437,11 +450,11 @@ void QMatrix::reconstruct(half* out)
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
if (!is_gptq)
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
reconstruct_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
@ -464,6 +477,7 @@ void QMatrix::reconstruct(half* out)
}
else
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
@ -523,10 +537,14 @@ __global__ void make_sequential_kernel
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
if (err != cudaSuccess) {
cudaError_t cuda_status = cudaGetLastError(); // Clear error
return false;
}
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
@ -600,4 +618,6 @@ void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
return true;
}

View File

@ -38,6 +38,8 @@ public:
half* temp_dq;
bool failed;
QMatrix
(
const int _device,
@ -62,7 +64,7 @@ public:
~QMatrix();
void reconstruct(half* out);
void make_sequential(const uint32_t* cpu_g_idx);
bool make_sequential(const uint32_t* cpu_g_idx);
private:

View File

@ -30,3 +30,13 @@ __forceinline__ __device__ float clamp(float x, float a, float b)
{
return fmaxf(a, fminf(b, x));
}
#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess)
{
fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
if (abort) exit(code);
}
}

View File

@ -168,7 +168,7 @@ def serve(
# When using GPTQ, Exllama kernels need some global kernels
# For which we have the finale shapes only after the model has loaded
# This will allocate those buffers.
from text_generation_server.utils.gptq.exllama import (
from text_generation_server.utils.layers import (
create_exllama_buffers,
set_device,
)

View File

@ -17,10 +17,6 @@ except ImportError:
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
def _torch_device(idx):
if idx == -1: return "cpu"
return f"cuda:{idx}"
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,)
@ -82,6 +78,53 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
none_tensor,
temp_dq)
DEVICE = None
FIXED_BYTES = 0
LAYERS = []
def set_device(device):
global DEVICE
DEVICE = device
def create_exllama_buffers():
global FIXED_BYTES, LAYERS, DEVICE
temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
for layer in LAYERS:
layer.post_init(temp_dq)
# assert DEVICE is not None, "call set_device first"
# if ACT_ORDER:
# # TODO: this should be set to rust side `max_total_tokens`, but TGI
# # does not offer an API to expose this variable to python, as this variable
# # is handled by the client but it appears the model is initialized by the server.
# # An alternative could be to initialize the buffers during warmup.
# # Dummy
# max_total_tokens = 2048
# else:
# max_total_tokens = 1
# # This temp_state buffer is required to reorder X in the act-order case.
# temp_state = torch.zeros(
# (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
# )
# temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
# # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
# prepare_buffers(DEVICE, temp_state, temp_dq)
# matmul_recons_thd = 8
# matmul_fused_remap = False
# matmul_no_half2 = False
# set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
# TEMP_STATE, TEMP_DQ = temp_state, temp_dq
class QuantLinear(nn.Module):
QUANT_TYPE = "exllamav2"
@ -98,7 +141,6 @@ class QuantLinear(nn.Module):
self.q_handle = None
self.q_tensors = None
# self.padding = - outfeatures % 32
#
# self.infeatures = infeatures
# self.outfeatures = outfeatures + self.padding
@ -112,7 +154,8 @@ class QuantLinear(nn.Module):
# assert infeatures % 32 == 0
# assert infeatures % self.group_size == 0
# assert outfeatures % 32 == 0
#
# self.padding = - outfeatures % 32
# # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
# self.register_buffer(
# 'qweight',
@ -137,13 +180,16 @@ class QuantLinear(nn.Module):
self.g_idx = g_idx
self.bias = bias if bias is not None else None
global FIXED_BYTES, LAYERS
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
LAYERS.append(self)
# if bias:
# self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
# else:
# self.bias = None
# def post_init(self, temp_dq):
temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index , self.temp_dq_size() + self.temp_fwd_size(4096, 8))
def post_init(self, temp_dq):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.q_tensors = {
@ -152,7 +198,7 @@ class QuantLinear(nn.Module):
"scales":self.scales,
"g_idx":self.g_idx
}
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size() + self.temp_fwd_size(4096, 8))
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
self.q_handle = ext_make_q_matrix(
self.q_tensors, temp_dq
)
@ -181,12 +227,12 @@ class ExLlamaV2DeviceTensors:
scratch_idx: int
scratch: torch.tensor = None
def __init__(self, device_idx, scratch_bytes):
self.device_idx = device_idx
def __init__(self, device, scratch_bytes):
self.device = device
self.scratch_bytes = scratch_bytes
def prepare(self):
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = _torch_device(self.device_idx))
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device)
def get_scratch_slice(self, size_bytes):

View File

@ -31,16 +31,27 @@ try:
major, _minor = torch.cuda.get_device_capability()
except Exception:
major = 1
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
try:
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
from text_generation_server.utils.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear
if V2:
from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear,
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "2"
else:
from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "1"
HAS_EXLLAMA = True
except ImportError:
pass
@ -309,7 +320,7 @@ def get_linear(weight, bias, quantize):
)
if use_exllama:
linear = exllamav2QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
linear = ExllamaQuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
else:
linear = QuantLinear(
qweight,

View File

@ -278,7 +278,7 @@ class Weights:
)
use_exllama = False
else:
logger.info("Using exllama kernels")
logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama:
if groupsize >= 0: