Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Fixing exllamav2.

This commit is contained in:
parent fb64ce1040
commit a61f432599
@@ -0,0 +1,38 @@
#ifndef _compat_gemm_cuh
#define _compat_gemm_cuh

#if defined(USE_ROCM)

// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required
// for symbols as hipblasHalf.
#include <hipblas/hipblas.h>

__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int m,
                                                               int n,
                                                               int k,
                                                               const half* alpha,
                                                               const half* AP,
                                                               int lda,
                                                               const half* BP,
                                                               int ldb,
                                                               const half* beta,
                                                               half* CP,
                                                               int ldc) {
    return hipblasHgemm(handle, transA, transB, m, n, k,
                        reinterpret_cast<const hipblasHalf *>(alpha),
                        reinterpret_cast<const hipblasHalf *>(AP), lda,
                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
                        reinterpret_cast<const hipblasHalf *>(beta),
                        reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm

// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif

#endif
@@ -18,34 +18,7 @@
#include "q_gemm_kernel.cuh"
#include "q_gemm_kernel_gptq.cuh"

#if defined(USE_ROCM)
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int m,
                                                               int n,
                                                               int k,
                                                               const half* alpha,
                                                               const half* AP,
                                                               int lda,
                                                               const half* BP,
                                                               int ldb,
                                                               const half* beta,
                                                               half* CP,
                                                               int ldc) {
    return hipblasHgemm(handle, transA, transB, m, n, k,
                        reinterpret_cast<const hipblasHalf *>(alpha),
                        reinterpret_cast<const hipblasHalf *>(AP), lda,
                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
                        reinterpret_cast<const hipblasHalf *>(beta),
                        reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm

// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif
#include "compat_gemm.cuh"

void gemm_half_q_half_cuda_part
(
@@ -1,5 +1,8 @@
#include "compat.cuh"

#include <cuda_runtime.h>
#include <cuda_fp16.h>

__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
    half2 result = {};
@@ -72,6 +72,8 @@ QMatrix::QMatrix
{
    cudaSetDevice(device);

    failed = false;

    cuda_q_weight = _q_weight;
    cuda_q_perm = _q_perm;
    cuda_q_invperm = _q_invperm;
@@ -125,7 +127,15 @@ QMatrix::QMatrix
        rows_3 = height;
        rows_2 = height;

        if (_gptq_g_idx) make_sequential(_gptq_g_idx);
        if (_gptq_g_idx)
        {
            if (!make_sequential(_gptq_g_idx))
            {
                failed = true;
                //printf("FAIL\n");
                return;
            }
        }
    }

    // Shuffle quantized data
@@ -139,6 +149,9 @@ QMatrix::QMatrix
    shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}

QMatrix::~QMatrix()
{
}

// Reconstruct b[k,n] (GPTQ)
@@ -437,11 +450,11 @@ void QMatrix::reconstruct(half* out)
    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
    gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);

    if (!is_gptq)
    {
        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
        reconstruct_kernel<<<gridDim, blockDim>>>
        (
            cuda_q_weight,
@@ -464,6 +477,7 @@ void QMatrix::reconstruct(half* out)
    }
    else
    {
        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
        reconstruct_gptq_kernel<<<gridDim, blockDim>>>
        (
            cuda_q_weight,
@@ -523,10 +537,14 @@ __global__ void make_sequential_kernel
    w_new2[w_new2_row * w2_stride + w2_column] = dst;
}

void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
{
    uint32_t* cuda_new_qweight = NULL;
    cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
    cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
    if (err != cudaSuccess) {
        cudaError_t cuda_status = cudaGetLastError(); // Clear error
        return false;
    }

    uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
    uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
@@ -600,4 +618,6 @@ void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
    free(cpu_g_idx_map);
    free(cpu_x_map);
    free(cpu_x_map_inv);

    return true;
}
@@ -38,6 +38,8 @@ public:

    half* temp_dq;

    bool failed;

    QMatrix
    (
        const int _device,
@@ -62,7 +64,7 @@ public:
    ~QMatrix();

    void reconstruct(half* out);
    void make_sequential(const uint32_t* cpu_g_idx);
    bool make_sequential(const uint32_t* cpu_g_idx);

private:
@@ -30,3 +30,13 @@ __forceinline__ __device__ float clamp(float x, float a, float b)
{
    return fmaxf(a, fminf(b, x));
}

#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
{
    if (code != cudaSuccess)
    {
        fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
        if (abort) exit(code);
    }
}
@@ -168,7 +168,7 @@ def serve(
        # When using GPTQ, Exllama kernels need some global kernels
        # For which we have the finale shapes only after the model has loaded
        # This will allocate those buffers.
        from text_generation_server.utils.gptq.exllama import (
        from text_generation_server.utils.layers import (
            create_exllama_buffers,
            set_device,
        )
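The three comment lines above state the ordering constraint behind this change: exllama's scratch buffers can only be sized after every quantized layer exists, and the import now comes from text_generation_server.utils.layers, which re-exports whichever kernel version was selected (see the layers.py hunk further down). A minimal sketch of that call order, assuming the TGI package and exllama kernels are importable; init_exllama_buffers and model_loader are illustrative names, not part of the repository:

import torch

from text_generation_server.utils.layers import (
    create_exllama_buffers,
    set_device,
)


def init_exllama_buffers(model_loader, device=torch.device("cuda:0")):
    # 1. Record the device the shared scratch tensors will live on.
    set_device(device)

    # 2. Loading the model constructs every exllama QuantLinear; each layer
    #    registers itself and its scratch requirement at construction time.
    model = model_loader()

    # 3. Only now are the final shapes known, so allocate the shared buffers
    #    and hand them to every registered layer.
    create_exllama_buffers()
    return model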
@@ -17,10 +17,6 @@ except ImportError:
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")

def _torch_device(idx):
    if idx == -1: return "cpu"
    return f"cuda:{idx}"

def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
    """Matrix multiplication, returns x @ q4"""
    output_shape = x.shape[:-1] + (q4_width,)
@@ -82,6 +78,53 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
        none_tensor,
        temp_dq)

DEVICE = None
FIXED_BYTES = 0
LAYERS = []


def set_device(device):
    global DEVICE
    DEVICE = device


def create_exllama_buffers():
    global FIXED_BYTES, LAYERS, DEVICE
    temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)

    for layer in LAYERS:
        layer.post_init(temp_dq)


# assert DEVICE is not None, "call set_device first"

# if ACT_ORDER:
#     # TODO: this should be set to rust side `max_total_tokens`, but TGI
#     # does not offer an API to expose this variable to python, as this variable
#     # is handled by the client but it appears the model is initialized by the server.
#     # An alternative could be to initialize the buffers during warmup.
#     # Dummy
#     max_total_tokens = 2048
# else:
#     max_total_tokens = 1

# # This temp_state buffer is required to reorder X in the act-order case.
# temp_state = torch.zeros(
#     (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
# )
# temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)

# # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
# prepare_buffers(DEVICE, temp_state, temp_dq)

# matmul_recons_thd = 8
# matmul_fused_remap = False
# matmul_no_half2 = False
# set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

# TEMP_STATE, TEMP_DQ = temp_state, temp_dq


class QuantLinear(nn.Module):
    QUANT_TYPE = "exllamav2"
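DEVICE, FIXED_BYTES and LAYERS implement a small deferred-allocation registry: each QuantLinear records the largest fixed scratch size seen so far and appends itself to LAYERS, and create_exllama_buffers later allocates a single ExLlamaV2DeviceTensors of that maximum size and passes it to every registered layer's post_init. A dependency-free sketch of the same pattern, with FakeLayer and the byte counts invented purely for illustration:

FIXED_BYTES = 0
LAYERS = []


class FakeLayer:
    def __init__(self, scratch_bytes):
        global FIXED_BYTES
        self.scratch_bytes = scratch_bytes
        self.temp_dq = None
        # Track the worst-case requirement rather than the sum: one shared
        # buffer is reused by every layer, so it only needs to fit the largest.
        FIXED_BYTES = max(FIXED_BYTES, scratch_bytes)
        LAYERS.append(self)

    def post_init(self, temp_dq):
        self.temp_dq = temp_dq


def create_buffers():
    # Stand-in for ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES): one allocation
    # sized for the most demanding layer, shared by all of them.
    shared = bytearray(FIXED_BYTES)
    for layer in LAYERS:
        layer.post_init(shared)


layers = [FakeLayer(1 << 20), FakeLayer(4 << 20), FakeLayer(2 << 20)]
create_buffers()
assert len(layers[0].temp_dq) == 4 << 20  # sized to the largest request
assert all(l.temp_dq is layers[0].temp_dq for l in layers)  # and shared by all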
@@ -98,7 +141,6 @@ class QuantLinear(nn.Module):

        self.q_handle = None
        self.q_tensors = None
        # self.padding = - outfeatures % 32
        #
        # self.infeatures = infeatures
        # self.outfeatures = outfeatures + self.padding
@@ -112,7 +154,8 @@ class QuantLinear(nn.Module):
        # assert infeatures % 32 == 0
        # assert infeatures % self.group_size == 0
        # assert outfeatures % 32 == 0
        #
        # self.padding = - outfeatures % 32

        # # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
        # self.register_buffer(
        #     'qweight',
@@ -137,13 +180,16 @@ class QuantLinear(nn.Module):
        self.g_idx = g_idx
        self.bias = bias if bias is not None else None

        global FIXED_BYTES, LAYERS
        FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
        LAYERS.append(self)

        # if bias:
        #     self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
        # else:
        #     self.bias = None

    # def post_init(self, temp_dq):
        temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index , self.temp_dq_size() + self.temp_fwd_size(4096, 8))
    def post_init(self, temp_dq):
        assert self.qweight.device.type == "cuda"
        assert self.qweight.device.index is not None
        self.q_tensors = {
@@ -152,7 +198,7 @@ class QuantLinear(nn.Module):
            "scales":self.scales,
            "g_idx":self.g_idx
        }
        temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size() + self.temp_fwd_size(4096, 8))
        temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
        self.q_handle = ext_make_q_matrix(
            self.q_tensors, temp_dq
        )
@@ -181,12 +227,12 @@ class ExLlamaV2DeviceTensors:
    scratch_idx: int
    scratch: torch.tensor = None

    def __init__(self, device_idx, scratch_bytes):
        self.device_idx = device_idx
    def __init__(self, device, scratch_bytes):
        self.device = device
        self.scratch_bytes = scratch_bytes

    def prepare(self):
        self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = _torch_device(self.device_idx))
        self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device)

    def get_scratch_slice(self, size_bytes):
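Passing a torch.device into ExLlamaV2DeviceTensors instead of a bare index is what allows the _torch_device helper removed earlier in this diff to go away: torch.empty accepts the device object directly. A small sketch of the equivalence, assuming a CUDA device is available (sizes are arbitrary):

import torch

scratch_bytes = 1 << 20
device = torch.device("cuda:0")

# Old style: keep an integer index and rebuild a device string at use sites.
old = torch.empty((scratch_bytes // 2,), dtype=torch.half, device=f"cuda:{device.index}")

# New style: hand the torch.device straight to torch.empty.
new = torch.empty((scratch_bytes // 2,), dtype=torch.half, device=device)

assert old.device == new.device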
@@ -31,16 +31,27 @@ try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1

HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
elif CAN_EXLLAMA:
    try:
        from text_generation_server.utils.gptq.exllama import Ex4bitLinear
        from text_generation_server.utils.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear
        if V2:
            from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )
            HAS_EXLLAMA = "2"
        else:
            from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )
            HAS_EXLLAMA = "1"

        HAS_EXLLAMA = True
    except ImportError:
        pass
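With V2 defaulting to "2", the exllamav2 kernels are chosen whenever the GPU is capable (major >= 8) and DISABLE_EXLLAMA is unset; setting EXLLAMA_VERSION to any other value falls back to the v1 kernels, and HAS_EXLLAMA becomes a version string instead of a plain boolean. A short sketch of pinning v1, assuming it runs before text_generation_server.utils.layers is first imported and that the v1 kernels import cleanly:

import os

# The selection happens at import time, so the variable must be set first.
os.environ["EXLLAMA_VERSION"] = "1"

from text_generation_server.utils.layers import HAS_EXLLAMA

print(HAS_EXLLAMA)  # "1" on a capable GPU, False otherwise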
@@ -309,7 +320,7 @@ def get_linear(weight, bias, quantize):
        )

        if use_exllama:
            linear = exllamav2QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
            linear = ExllamaQuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
        else:
            linear = QuantLinear(
                qweight,
@@ -278,7 +278,7 @@ class Weights:
                )
                use_exllama = False
            else:
                logger.info("Using exllama kernels")
                logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")

        if use_exllama:
            if groupsize >= 0:
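Since HAS_EXLLAMA is now the string "1" or "2" rather than True, the log line can report which kernel generation is in use, while any `if HAS_EXLLAMA:` style check keeps working because non-empty strings are truthy. A trivial illustration:

HAS_EXLLAMA = "2"  # as set in layers.py when the exllamav2 kernels import cleanly

if HAS_EXLLAMA:  # same outcome as when the flag was the boolean True
    print(f"Using exllama kernels v{HAS_EXLLAMA}")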