mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Fixing exllamav2.
This commit is contained in:
parent
fb64ce1040
commit
a61f432599
@@ -0,0 +1,38 @@
+#ifndef _compat_gemm_cuh
+#define _compat_gemm_cuh
+
+#if defined(USE_ROCM)
+
+// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required
+// for symbols as hipblasHalf.
+#include <hipblas/hipblas.h>
+
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
+                                                               hipblasOperation_t transA,
+                                                               hipblasOperation_t transB,
+                                                               int m,
+                                                               int n,
+                                                               int k,
+                                                               const half* alpha,
+                                                               const half* AP,
+                                                               int lda,
+                                                               const half* BP,
+                                                               int ldb,
+                                                               const half* beta,
+                                                               half* CP,
+                                                               int ldc) {
+    return hipblasHgemm(handle, transA, transB, m, n, k,
+                        reinterpret_cast<const hipblasHalf *>(alpha),
+                        reinterpret_cast<const hipblasHalf *>(AP), lda,
+                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
+                        reinterpret_cast<const hipblasHalf *>(beta),
+                        reinterpret_cast<hipblasHalf *>(CP), ldc);
+}
+#define hipblasHgemm __compat_hipblasHgemm
+
+// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
+#define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_hgemm __compat_hipblasHgemm
+#endif
+
+#endif
@@ -18,34 +18,7 @@
 #include "q_gemm_kernel.cuh"
 #include "q_gemm_kernel_gptq.cuh"
 
-#if defined(USE_ROCM)
-__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
-                                                               hipblasOperation_t transA,
-                                                               hipblasOperation_t transB,
-                                                               int m,
-                                                               int n,
-                                                               int k,
-                                                               const half* alpha,
-                                                               const half* AP,
-                                                               int lda,
-                                                               const half* BP,
-                                                               int ldb,
-                                                               const half* beta,
-                                                               half* CP,
-                                                               int ldc) {
-    return hipblasHgemm(handle, transA, transB, m, n, k,
-                        reinterpret_cast<const hipblasHalf *>(alpha),
-                        reinterpret_cast<const hipblasHalf *>(AP), lda,
-                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
-                        reinterpret_cast<const hipblasHalf *>(beta),
-                        reinterpret_cast<hipblasHalf *>(CP), ldc);
-}
-#define hipblasHgemm __compat_hipblasHgemm
-
-// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
-#define rocblas_operation_none HIPBLAS_OP_N
-#define rocblas_hgemm __compat_hipblasHgemm
-#endif
+#include "compat_gemm.cuh"
 
 void gemm_half_q_half_cuda_part
 (
@@ -1,5 +1,8 @@
 #include "compat.cuh"
 
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
 __forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
 {
     half2 result = {};
@@ -72,6 +72,8 @@ QMatrix::QMatrix
 {
     cudaSetDevice(device);
 
+    failed = false;
+
     cuda_q_weight = _q_weight;
     cuda_q_perm = _q_perm;
     cuda_q_invperm = _q_invperm;
@@ -125,7 +127,15 @@ QMatrix::QMatrix
     rows_3 = height;
     rows_2 = height;
 
-    if (_gptq_g_idx) make_sequential(_gptq_g_idx);
+    if (_gptq_g_idx)
+    {
+        if (!make_sequential(_gptq_g_idx))
+        {
+            failed = true;
+            //printf("FAIL\n");
+            return;
+        }
+    }
 }
 
 // Shuffle quantized data
@@ -139,6 +149,9 @@ QMatrix::QMatrix
     shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
 }
 
+QMatrix::~QMatrix()
+{
+}
 
 // Reconstruct b[k,n] (GPTQ)
 
@@ -437,11 +450,11 @@ void QMatrix::reconstruct(half* out)
     dim3 blockDim, gridDim;
     blockDim.x = BLOCK_KN_SIZE;
     blockDim.y = 1;
-    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
 
     if (!is_gptq)
     {
+        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
         reconstruct_kernel<<<gridDim, blockDim>>>
         (
             cuda_q_weight,
@@ -464,6 +477,7 @@ void QMatrix::reconstruct(half* out)
     }
     else
     {
+        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
         reconstruct_gptq_kernel<<<gridDim, blockDim>>>
         (
             cuda_q_weight,
@@ -523,10 +537,14 @@ __global__ void make_sequential_kernel
     w_new2[w_new2_row * w2_stride + w2_column] = dst;
 }
 
-void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
+bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
 {
     uint32_t* cuda_new_qweight = NULL;
-    cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
+    cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
+    if (err != cudaSuccess) {
+        cudaError_t cuda_status = cudaGetLastError(); // Clear error
+        return false;
+    }
 
     uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
     uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
@@ -600,4 +618,6 @@ void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
     free(cpu_g_idx_map);
     free(cpu_x_map);
     free(cpu_x_map_inv);
+
+    return true;
 }
@@ -38,6 +38,8 @@ public:
 
     half* temp_dq;
 
+    bool failed;
+
     QMatrix
     (
         const int _device,
@@ -62,7 +64,7 @@ public:
     ~QMatrix();
 
     void reconstruct(half* out);
-    void make_sequential(const uint32_t* cpu_g_idx);
+    bool make_sequential(const uint32_t* cpu_g_idx);
 
 private:
 
@@ -30,3 +30,13 @@ __forceinline__ __device__ float clamp(float x, float a, float b)
 {
     return fmaxf(a, fminf(b, x));
 }
+
+#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
+inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
+{
+    if (code != cudaSuccess)
+    {
+        fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
+        if (abort) exit(code);
+    }
+}
@@ -168,7 +168,7 @@ def serve(
         # When using GPTQ, Exllama kernels need some global kernels
         # For which we have the finale shapes only after the model has loaded
         # This will allocate those buffers.
-        from text_generation_server.utils.gptq.exllama import (
+        from text_generation_server.utils.layers import (
            create_exllama_buffers,
            set_device,
        )
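
Note on the hunk above: the buffer helpers are now re-exported through text_generation_server.utils.layers, and the intended call order at serve time is roughly the following. This is a minimal sketch; the surrounding model-loading code and the `model.device` attribute are assumptions, not part of this diff.

    from text_generation_server.utils.layers import (
        create_exllama_buffers,
        set_device,
    )

    # ... after the model (and its exllamav2 QuantLinear layers) has been loaded ...
    set_device(model.device)    # record the device the scratch space must live on
    create_exllama_buffers()    # allocate one shared scratch buffer and finish layer init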
@@ -17,10 +17,6 @@ except ImportError:
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
 none_tensor = torch.empty((1, 1), device="meta")
 
-def _torch_device(idx):
-    if idx == -1: return "cpu"
-    return f"cuda:{idx}"
-
 def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
     """Matrix multiplication, returns x @ q4"""
     output_shape = x.shape[:-1] + (q4_width,)
@@ -82,6 +78,53 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             none_tensor,
             temp_dq)
 
+DEVICE = None
+FIXED_BYTES = 0
+LAYERS = []
+
+
+def set_device(device):
+    global DEVICE
+    DEVICE = device
+
+
+def create_exllama_buffers():
+    global FIXED_BYTES, LAYERS, DEVICE
+    temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
+
+    for layer in LAYERS:
+        layer.post_init(temp_dq)
+
+
+    # assert DEVICE is not None, "call set_device first"
+
+    # if ACT_ORDER:
+    #     # TODO: this should be set to rust side `max_total_tokens`, but TGI
+    #     # does not offer an API to expose this variable to python, as this variable
+    #     # is handled by the client but it appears the model is initialized by the server.
+    #     # An alternative could be to initialize the buffers during warmup.
+    #     # Dummy
+    #     max_total_tokens = 2048
+    # else:
+    #     max_total_tokens = 1
+
+    # # This temp_state buffer is required to reorder X in the act-order case.
+    # temp_state = torch.zeros(
+    #     (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
+    # )
+    # temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
+
+    # # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
+    # prepare_buffers(DEVICE, temp_state, temp_dq)
+
+    # matmul_recons_thd = 8
+    # matmul_fused_remap = False
+    # matmul_no_half2 = False
+    # set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+    # TEMP_STATE, TEMP_DQ = temp_state, temp_dq
+
+
 class QuantLinear(nn.Module):
     QUANT_TYPE = "exllamav2"
 
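
Note on the hunk above: the new module-level state implements a register-then-finalize pattern. Each exllamav2 QuantLinear records its scratch requirement and appends itself to LAYERS at construction time; create_exllama_buffers() then builds a single ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES) and hands it to every registered layer's post_init. A rough sketch of that lifecycle (DummyLayer and the byte count are illustrative stand-ins, not part of the diff):

    import torch
    from text_generation_server.utils.gptq import exllamav2 as ex2

    class DummyLayer:
        # Stand-in for QuantLinear, showing only the registration handshake.
        def scratch_space_fixed(self):
            return 16 * 1024 * 1024  # hypothetical scratch requirement, in bytes

        def post_init(self, temp_dq):
            self.temp_dq = temp_dq   # every layer shares the same device scratch

    layer = DummyLayer()
    # QuantLinear.__init__ does the equivalent of these two lines (see the later hunk):
    ex2.FIXED_BYTES = max(ex2.FIXED_BYTES, layer.scratch_space_fixed())
    ex2.LAYERS.append(layer)

    ex2.set_device(torch.device("cuda:0"))
    ex2.create_exllama_buffers()  # one ExLlamaV2DeviceTensors for all registered layers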
@@ -98,7 +141,6 @@ class QuantLinear(nn.Module):
 
         self.q_handle = None
         self.q_tensors = None
-        # self.padding = - outfeatures % 32
         #
         # self.infeatures = infeatures
         # self.outfeatures = outfeatures + self.padding
@@ -112,7 +154,8 @@ class QuantLinear(nn.Module):
         # assert infeatures % 32 == 0
         # assert infeatures % self.group_size == 0
         # assert outfeatures % 32 == 0
-        #
+        # self.padding = - outfeatures % 32
+
         # # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
         # self.register_buffer(
         #     'qweight',
|
|||||||
self.g_idx = g_idx
|
self.g_idx = g_idx
|
||||||
self.bias = bias if bias is not None else None
|
self.bias = bias if bias is not None else None
|
||||||
|
|
||||||
|
global FIXED_BYTES, LAYERS
|
||||||
|
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
|
||||||
|
LAYERS.append(self)
|
||||||
|
|
||||||
# if bias:
|
# if bias:
|
||||||
# self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
# self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||||
# else:
|
# else:
|
||||||
# self.bias = None
|
# self.bias = None
|
||||||
|
|
||||||
# def post_init(self, temp_dq):
|
def post_init(self, temp_dq):
|
||||||
temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index , self.temp_dq_size() + self.temp_fwd_size(4096, 8))
|
|
||||||
assert self.qweight.device.type == "cuda"
|
assert self.qweight.device.type == "cuda"
|
||||||
assert self.qweight.device.index is not None
|
assert self.qweight.device.index is not None
|
||||||
self.q_tensors = {
|
self.q_tensors = {
|
||||||
@@ -152,7 +198,7 @@ class QuantLinear(nn.Module):
             "scales":self.scales,
             "g_idx":self.g_idx
         }
-        temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size() + self.temp_fwd_size(4096, 8))
+        temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
         self.q_handle = ext_make_q_matrix(
             self.q_tensors, temp_dq
         )
@@ -181,12 +227,12 @@ class ExLlamaV2DeviceTensors:
     scratch_idx: int
     scratch: torch.tensor = None
 
-    def __init__(self, device_idx, scratch_bytes):
-        self.device_idx = device_idx
+    def __init__(self, device, scratch_bytes):
+        self.device = device
         self.scratch_bytes = scratch_bytes
 
     def prepare(self):
-        self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = _torch_device(self.device_idx))
+        self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device)
 
     def get_scratch_slice(self, size_bytes):
 
@@ -31,16 +31,27 @@ try:
     major, _minor = torch.cuda.get_device_capability()
 except Exception:
     major = 1
 
 HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8
+V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
 elif CAN_EXLLAMA:
     try:
-        from text_generation_server.utils.gptq.exllama import Ex4bitLinear
-        from text_generation_server.utils.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear
-        HAS_EXLLAMA = True
+        if V2:
+            from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear,
+                                                                     create_exllama_buffers,
+                                                                     set_device,
+                                                                     )
+            HAS_EXLLAMA = "2"
+        else:
+            from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
+                                                                   create_exllama_buffers,
+                                                                   set_device,
+                                                                   )
+            HAS_EXLLAMA = "1"
+
     except ImportError:
         pass
 
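
Note on the hunk above: kernel selection is now driven by two environment variables, read once when text_generation_server.utils.layers is first imported. A short usage sketch (the print is only illustrative):

    import os

    os.environ["EXLLAMA_VERSION"] = "1"      # opt back into the exllama v1 kernels (default is "2")
    # os.environ["DISABLE_EXLLAMA"] = "True" # or disable exllama kernels entirely

    from text_generation_server.utils.layers import HAS_EXLLAMA
    print(HAS_EXLLAMA)  # "2", "1", or False, depending on the env and on GPU capability (major >= 8)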
@@ -309,7 +320,7 @@ def get_linear(weight, bias, quantize):
             )
 
     if use_exllama:
-        linear = exllamav2QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
+        linear = ExllamaQuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
     else:
         linear = QuantLinear(
             qweight,
|
@ -278,7 +278,7 @@ class Weights:
|
|||||||
)
|
)
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
else:
|
else:
|
||||||
logger.info("Using exllama kernels")
|
logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||||
|
|
||||||
if use_exllama:
|
if use_exllama:
|
||||||
if groupsize >= 0:
|
if groupsize >= 0:
|
||||||
|