Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

Commit 33111a0e7f ("wip"), parent 50a20a83d7
.gitignore (vendored): 5 lines changed
@@ -2,3 +2,8 @@
 target
 router/tokenizer.json
 *__pycache__*
+
+# ROCm auto-generated files
+*.hip
+server/exllamav2_kernels/exllamav2_kernels/hip/
+server/exllama_kernels/exllama_kernels/hip/
@@ -104,6 +104,20 @@ WORKDIR /usr/src
 COPY server/custom_kernels/ .
 RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
 
+# Build exllama kernels
+FROM kernel-builder as exllama-kernels-builder
+WORKDIR /usr/src
+COPY server/exllama_kernels/ .
+
+RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+
+# Build exllama v2 kernels
+FROM kernel-builder as exllamav2-kernels-builder
+WORKDIR /usr/src
+COPY server/exllamav2_kernels/ .
+
+RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+
 FROM base as base-copy
 
 # Text Generation Inference base env
@@ -120,6 +134,12 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86
 # Copy build artifacts from custom kernels builder
 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+
+# Copy build artifacts from exllama kernels builder
+COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+
+# Copy build artifacts from exllamav2 kernels builder
+COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
@@ -149,5 +169,5 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base-copy
 
-ENTRYPOINT ["text-generation-launcher"]
-CMD ["--json-output"]
+#ENTRYPOINT ["text-generation-launcher"]
+#CMD ["--json-output"]
@@ -43,12 +43,12 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 
 //
 
-#if defined(__CUDA_ARCH__)
-#if __CUDA_ARCH__ < 700
+#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
 
 __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
 
-#if __CUDA_ARCH__ < 600
+#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
 __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
 #endif
 
@@ -2,8 +2,11 @@
 #include "column_remap.cuh"
 #include "../util.cuh"
 #include "../matrix.cuh"
-#include "../cuda_compat.cuh"
+#include "../cu_compat.cuh"
 #include "../cuda_buffers.cuh"
+#if defined(USE_ROCM)
+#include "../hip_compat.cuh"
+#endif
 
 const int THREADS_X = 32;       // Block size and thread count along columns in w and out
 const int THREADS_Y = 1;        // Block size and thread count along rows in x and out

@@ -128,7 +131,7 @@ __global__ void q4_matmul_kernel
 
     if constexpr (use_half2)
     {
-        half result = __hadd(acc.x, acc.y);
+        half result = __hadd(__low2half(acc), __high2half(acc));
         atomicAdd(out_.item_ptr(x_row, w_column), result);
     }
     else
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp (new file, 250 lines)

// !!! This is a file automatically generated by hipify!!!
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#include <torch/extension.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdio>
#include "util_hip.cuh"
#include "tuning.h"
#include "hip_buffers.cuh"
#include "hip_func/q4_matrix.cuh"
#include "hip_func/q4_matmul.cuh"
#include "hip_func/column_remap.cuh"

// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a hipError_t which we can parse and dump to the console.

void check_cuda(hipError_t ret)
{
    switch (ret)
    {
        case hipSuccess:
            break;

        case cudaUnspecified:
            printf(" **** Unspecified error\n");
            TORCH_CHECK(false, "CUDA error");
            break;

        default:
            printf(" **** CUDA error\n"); \
            printf(" **** %s\n", hipGetErrorString(ret)); \
            TORCH_CHECK(false, "CUDA error"); \
            break;
    }
}

// Some decluttering macros

#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))

#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
    TORCH_CHECK(__index >= 0, "no device index"); \
    TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)

#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
    TORCH_CHECK_DTYPE(__w, kInt); \
    TORCH_CHECK_DTYPE(__w_scales, kHalf); \
    TORCH_CHECK_DTYPE(__w_zeros, kInt); \
    TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
    TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
    TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
    TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)

int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
    int groupsize = w.size(0) * 8 / w_zeros.size(0);
    TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
    return groupsize;
}


// Tuning parameters

ExLlamaTuning tuningParams;

void set_tuning_params
(
    int matmul_recons_thd,
    bool matmul_fused_remap,
    bool matmul_no_half2
)
{
    tuningParams.matmul_recons_thd = matmul_recons_thd;
    tuningParams.matmul_fused_remap = matmul_fused_remap;
    tuningParams.matmul_no_half2 = matmul_no_half2;
}


// Release all unmanaged objects allocated by the extension

void cleanup()
{
    cleanup_buffers_cuda();
    g_q4_free_matrices();
}


// Prepare buffers for forward pass

void prepare_buffers
(
    torch::Device device,
    torch::Tensor temp_state,
    torch::Tensor temp_dq
)
{
    int device_index = device.index();
    TORCH_CHECK_DEVICE_INDEX(device_index);
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device);

    prepare_buffers_cuda
    (
        device_index,
        (half*) temp_state.data_ptr(),
        (half*) temp_dq.data_ptr()
    );
}


// Create Q4Matrix, return handle

uintptr_t make_q4
(
    torch::Tensor qweight,
    torch::Tensor qzeros,
    torch::Tensor scales,
    torch::Tensor g_idx,
    int device
)
{
    TORCH_CHECK_DTYPE(qweight, kInt);
    TORCH_CHECK_DTYPE(qzeros, kInt);
    TORCH_CHECK_DTYPE(scales, kHalf);
    TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
    TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
    TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
    TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);

    int width = qweight.size(1);
    int height = qweight.size(0) * 8;
    int groups = qzeros.size(0);

    Q4Matrix* m = new Q4Matrix
    (
        height,
        width,
        groups,

        (uint32_t*) qweight.data_ptr(),
        (uint32_t*) qzeros.data_ptr(),
        (half*) scales.data_ptr(),
        g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),

        device
    );

    g_q4_keep_matrix(m);
    return reinterpret_cast<uintptr_t> (m);
}


// Matmul half @ quant -> half

void q4_matmul
(
    torch::Tensor x,
    uintptr_t w,
    torch::Tensor out
)
{
    Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);

    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(out, kHalf);
    TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
    TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")

    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(x));

    int x_height = x.size(0);

    if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
    {
        q4_matmul_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr()
        );
    }
    else
    {
        q4_matmul_recons_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr(),
            at::cuda::getCurrentCUDABlasHandle()
        );
    }
}


// Remap columns in half tensor

void column_remap
(
    torch::Tensor x,
    torch::Tensor x_new,
    torch::Tensor x_map
)
{
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(x_new, kHalf);
    TORCH_CHECK_DTYPE(x_map, kInt);
    TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);

    int height = x.size(0);
    int width = x.size(1);

    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(x));

    column_remap_cuda
    (
        (half*) x.data_ptr(),
        (half*) x_new.data_ptr(),
        height,
        width,
        (uint32_t*) x_map.data_ptr()
    );
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
    m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
    m.def("cleanup", &cleanup, "cleanup");
    m.def("make_q4", &make_q4, "make_q4");
    m.def("q4_matmul", &q4_matmul, "q4_matmul");
}
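For orientation, the entry points bound above are the ones TGI's GPTQ path imports from the exllama_kernels module. The following is only a rough usage sketch, not part of this commit; the shapes, group size, and scratch-buffer sizes are illustrative assumptions.

# Rough usage sketch only (not part of this commit). Shapes, group size and
# buffer sizes are illustrative assumptions; the calls mirror the PYBIND11
# bindings defined above.
import torch
from exllama_kernels import set_tuning_params, prepare_buffers, make_q4, q4_matmul

device = torch.device("cuda:0")  # resolves to a HIP device on ROCm builds of PyTorch

# Assumed 4-bit GPTQ layout: in_features = out_features = 4096, group size 128,
# eight 4-bit weights packed per int32.
qweight = torch.zeros((4096 // 8, 4096), dtype=torch.int32, device=device)
qzeros = torch.zeros((4096 // 128, 4096 // 8), dtype=torch.int32, device=device)
scales = torch.ones((4096 // 128, 4096), dtype=torch.float16, device=device)
g_idx = torch.empty((0,), dtype=torch.int32, device="meta")  # meta tensor means "not provided"

# Scratch buffers handed to prepare_buffers (temp_state, temp_dq).
temp_state = torch.zeros((1024, 4096), dtype=torch.float16, device=device)
temp_dq = torch.zeros((4096 * 4096,), dtype=torch.float16, device=device)

set_tuning_params(8, False, False)   # matmul_recons_thd, matmul_fused_remap, matmul_no_half2
prepare_buffers(device, temp_state, temp_dq)

q4_handle = make_q4(qweight, qzeros, scales, g_idx, device.index)

x = torch.randn((4, 4096), dtype=torch.float16, device=device)
out = torch.empty((4, 4096), dtype=torch.float16, device=device)
q4_matmul(x, q4_handle, out)         # out = x @ dequantize(q4)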
server/exllama_kernels/exllama_kernels/hip_buffers.cuh (new file, 53 lines)

// !!! This is a file automatically generated by hipify!!!
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdio>

const int CUDA_MAX_DEVICES = 16;

// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif

class CudaBuffers
{
public:
    int device;

    half* temp_state;   // [max_hidden_rows * intermediate_size]
    half* temp_dq;      // size of largest quant tensor * 8

    hipStream_t alt_stream_1;
    hipStream_t alt_stream_2;
    hipStream_t alt_stream_3;
    hipEvent_t alt_stream_1_done;
    hipEvent_t alt_stream_2_done;
    hipEvent_t alt_stream_3_done;

    CudaBuffers
    (
        int _device,
        half* _temp_state,
        half* _temp_dq
    );
    ~CudaBuffers();
};

CudaBuffers* get_buffers(const int device_index);

void prepare_buffers_cuda
(
    int _device,
    half* _temp_state,
    half* _temp_dq
);

void cleanup_buffers_cuda();

#endif
server/exllama_kernels/exllama_kernels/hip_compat.cuh (new file, 51 lines)

// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _hip_compat_cuh
#define _hip_compat_cuh

// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
    return __half_raw{
        static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}

__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}

#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp

// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int                m,
                                                               int                n,
                                                               int                k,
                                                               const half*        alpha,
                                                               const half*        AP,
                                                               int                lda,
                                                               const half*        BP,
                                                               int                ldb,
                                                               const half*        beta,
                                                               half*              CP,
                                                               int                ldc) {
    return hipblasHgemm(handle, transA, transB, m, n, k,
                        reinterpret_cast<const hipblasHalf *>(alpha),
                        reinterpret_cast<const hipblasHalf *>(AP), lda,
                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
                        reinterpret_cast<const hipblasHalf *>(beta),
                        reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm

// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm

#endif
New file (20 lines)

// !!! This is a file automatically generated by hipify!!!
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _column_remap_cuh
#define _column_remap_cuh

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
);

#endif
New file (38 lines)

// !!! This is a file automatically generated by hipify!!!
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _q4_matmul_cuh
#define _q4_matmul_cuh

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/hip/HIPContext.h>

#include "../hip_func/q4_matrix.cuh"
#include "../tuning.h"

void q4_matmul_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    const Q4Matrix* w,
    half* out,
    bool no_zero = false,
    hipStream_t alt_stream = NULL
);

void q4_matmul_recons_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    Q4Matrix* w,
    half* out,
    const hipblasHandle_t handle,
    bool no_zero = false
);

#endif
New file (54 lines)

// !!! This is a file automatically generated by hipify!!!
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _q4_matrix_cuh
#define _q4_matrix_cuh

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>

class Q4Matrix
{
public:

    int device;

    int height;
    int width;
    int groups;
    int groupsize;

    uint32_t* cuda_qweight = NULL;
    uint32_t* cuda_qzeros = NULL;
    half* cuda_scales = NULL;
    uint32_t* cuda_x_map = NULL;

    Q4Matrix
    (
        const int _height,
        const int _width,
        const int _groups,

        uint32_t* _qweight,
        uint32_t* _qzeros,
        half* _scales,
        uint32_t* _g_idx,

        const int _device
    );

    ~Q4Matrix();

    void reconstruct(half* out);

private:

    void make_sequential(const uint32_t* cpu_g_idx);

};

void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();

#endif
server/exllama_kernels/exllama_kernels/matrix_hip.cuh (new file, 295 lines)

// !!! This is a file automatically generated by hipify!!!
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _matrix_cuh
#define _matrix_cuh

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};

class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};

class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
    }
};

class MatrixView_q4_column
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};

// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu

// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale

__device__ __forceinline__ half2 dot_product_8
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,            // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,               // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,         // + 1 (!!)
    const int count
)
{
    const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half2 v_01 = __halves2half2(v_0, v_1);
        half2 v_23 = __halves2half2(v_2, v_3);
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

        // half2 v_01 = q4_table[v_zero - 1][(v_read      ) & 0xff]; // (constant memory is too slow apparently)
        // half2 v_23 = q4_table[v_zero - 1][(v_read >>  8) & 0xff];
        // half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
        // half2 v_67 = q4_table[v_zero - 1][(v_read >> 24)       ];

        half2 tmp = __hmul2(*h_ptr++, v_01);
        tmp = __hfma2(*h_ptr++, v_23, tmp);
        tmp = __hfma2(*h_ptr++, v_45, tmp);
        tmp = __hfma2(*h_ptr++, v_67, tmp);
        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}

__device__ __forceinline__ half dot_product_8_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,            // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,               // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,         // + 1 (!!)
    const int count
)
{
    const half* h_ptr = h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half tmp = __hmul(*h_ptr++, v_0);
        tmp = __hfma(*h_ptr++, v_1, tmp);
        tmp = __hfma(*h_ptr++, v_2, tmp);
        tmp = __hfma(*h_ptr++, v_3, tmp);
        tmp = __hfma(*h_ptr++, v_4, tmp);
        tmp = __hfma(*h_ptr++, v_5, tmp);
        tmp = __hfma(*h_ptr++, v_6, tmp);
        tmp = __hfma(*h_ptr++, v_7, tmp);
        result = __hfma(v_scale, tmp, result);
    }

    return result;
}

// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map

__device__ __forceinline__ half2 dot_product_8_x_map
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,            // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,               // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,         // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    const half* h_ptr = h_.item_ptr(h_row, 0);
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half2 v_01 = __halves2half2(v_0, v_1);
        half2 v_23 = __halves2half2(v_2, v_3);
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

        half h_0 = h_ptr[*x_map_ptr++];
        half h_1 = h_ptr[*x_map_ptr++];
        half h_2 = h_ptr[*x_map_ptr++];
        half h_3 = h_ptr[*x_map_ptr++];
        half h_4 = h_ptr[*x_map_ptr++];
        half h_5 = h_ptr[*x_map_ptr++];
        half h_6 = h_ptr[*x_map_ptr++];
        half h_7 = h_ptr[*x_map_ptr++];

        half2 h_01 = __halves2half2(h_0, h_1);
        half2 h_23 = __halves2half2(h_2, h_3);
        half2 h_45 = __halves2half2(h_4, h_5);
        half2 h_67 = __halves2half2(h_6, h_7);

        half2 tmp = __hmul2(h_01, v_01);
        tmp = __hfma2(h_23, v_23, tmp);
        tmp = __hfma2(h_45, v_45, tmp);
        tmp = __hfma2(h_67, v_67, tmp);
        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}

__device__ __forceinline__ half dot_product_8_x_map_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,            // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,               // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,         // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    const half* h_ptr = h_.item_ptr(h_row, 0);
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
        result = __hfma(v_scale, tmp, result);
    }

    return result;
}

#endif
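Each dot-product helper above begins by unpacking eight 4-bit weights from a single 32-bit word and subtracting the pre-incremented zero point before the fused multiply-adds. A small standalone Python sketch of just that unpacking step, for illustration only:

# Illustrative only: Python rendering of the nibble unpacking used by the
# dot_product_8* helpers above. One uint32 packs eight consecutive 4-bit
# weights of a column; v_zero is the already +1-adjusted zero point.
def unpack_q4_word(v_read: int, v_zero: int) -> list:
    return [((v_read >> (4 * i)) & 0x0F) - v_zero for i in range(8)]

print(unpack_q4_word(0x76543210, 1))  # -> [-1, 0, 1, 2, 3, 4, 5, 6]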
@@ -8,7 +8,11 @@
 #include <cstdint>
 #include <cstdio>
 
+#if defined(USE_ROCM)
+#define cudaUnspecified hipErrorUnknown
+#else
 #define cudaUnspecified cudaErrorApiFailureBase
+#endif
 
 // React to failure on return code != cudaSuccess
 
server/exllama_kernels/exllama_kernels/util_hip.cuh (new file, 34 lines)

// !!! This is a file automatically generated by hipify!!!
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _util_cuh
#define _util_cuh

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdio>

#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified hipErrorApiFailureBase
#endif

// React to failure on return code != hipSuccess

#define _cuda_check(fn) \
do { \
    {_cuda_err = fn;} \
    if (_cuda_err != hipSuccess) goto _cuda_fail; \
} while(false)

// React to failure on return code == 0

#define _alloc_check(fn) \
do { \
    if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
    else _cuda_err = hipSuccess; \
} while(false)

#endif
@@ -22,6 +22,11 @@
 
 #include "q_gemm_kernel.cuh"
 #include "q_gemm_kernel_gptq.cuh"
+#include "compat_gemm.cuh"
+
+#include <iostream>
+#include <fstream>
+using namespace std;
 
 void gemm_half_q_half_cuda_part
 (

@@ -38,8 +43,11 @@ void gemm_half_q_half_cuda_part
     bool mul_r_weights
 )
 {
+    ofstream myfile;
+    myfile.open ("/tgi/server/exllamav2_kernels/log.txt");
     if (!b->is_gptq)
     {
+        myfile << "go in is_gptq path" << "\n";
         dim3 blockDim, gridDim;
         blockDim.x = EXL2_BLOCK_KN_SIZE;
         blockDim.y = 1;

@@ -50,6 +58,7 @@ void gemm_half_q_half_cuda_part
 
         fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
 
+        myfile << "launch kernel" << "\n";
         kernel<<<gridDim, blockDim>>>
         (
             a,

@@ -110,6 +119,7 @@ void gemm_half_q_half_cuda_part
             r_weights_stride
         );
     }
+    myfile.close();
 }
 
 void gemm_half_q_half_cuda

@@ -129,8 +139,11 @@ void gemm_half_q_half_cuda
     bool mul_r_weights
 )
 {
+    ofstream myfile;
+    myfile.open ("/tgi/server/exllamav2_kernels/log.txt");
     if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
     {
+        myfile << "going in cublas path" << "\n";
         // Reconstruct FP16 matrix, then cuBLAS
 
         if (!temp_dq) temp_dq = b->temp_dq;

@@ -172,6 +185,9 @@ void gemm_half_q_half_cuda
     {
         // Quantized matmul
 
+        myfile << "going in gemm_half_q_half_cuda_part path" << "\n";
+
+
         int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
         int max_chunks = size_m / block_m_size_max;
         int last_chunk = max_chunks * block_m_size_max;

@@ -179,14 +195,17 @@ void gemm_half_q_half_cuda
 
         if (max_chunks)
         {
+            myfile << "call gemm_half_q_half_cuda_part max_chunks" << "\n";
             gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
         }
 
         if (last_chunk_size)
         {
+            myfile << "call gemm_half_q_half_cuda_part last_chunk_size" << "\n";
             gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
         }
     }
+    myfile.close();
 }
 
 __global__ void clear_kernel
@@ -13,6 +13,10 @@
 
 #include "cpp/util.h"
 
+#include <iostream>
+#include <fstream>
+using namespace std;
+
 // Some decluttering macros
 
 #define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)

@@ -105,6 +109,10 @@ void gemm_half_q_half
     bool force_cuda
 )
 {
+    ofstream myfile;
+    myfile.open ("/tgi/server/exllamav2_kernels/log.txt");
+    myfile << "start gemm_half_q_half" << "\n";
+
     QMatrix* qm = reinterpret_cast<QMatrix*> (b);
 
     TORCH_CHECK_DTYPE(a, kHalf);

@@ -115,6 +123,7 @@ void gemm_half_q_half
 
     const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
 
+    myfile << "call gemm_half_q_half_cuda" << "\n";
     gemm_half_q_half_cuda
     (
         at::cuda::getCurrentCUDABlasHandle(),

@@ -128,6 +137,7 @@ void gemm_half_q_half
         NULL,
         force_cuda
     );
+    myfile.close();
 }
 
 // Bindings
|
150
server/exllamav2_kernels/exllamav2_kernels/ext_hip.cpp
Normal file
150
server/exllamav2_kernels/exllamav2_kernels/ext_hip.cpp
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
// !!! This is a file automatically generated by hipify!!!
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
|
||||||
|
#include <ATen/hip/HIPContext.h>
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
#include <hip/hip_fp16.h>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
|
||||||
|
#include "config.h"
|
||||||
|
|
||||||
|
#include "hip/q_matrix.cuh"
|
||||||
|
#include "hip/q_gemm.cuh"
|
||||||
|
|
||||||
|
#include "cpp/util.h"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
// Some decluttering macros
|
||||||
|
|
||||||
|
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||||
|
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||||
|
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||||
|
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||||
|
|
||||||
|
|
||||||
|
// Quant matrix
|
||||||
|
|
||||||
|
uintptr_t make_q_matrix
|
||||||
|
(
|
||||||
|
torch::Tensor q_weight,
|
||||||
|
torch::Tensor q_perm,
|
||||||
|
torch::Tensor q_invperm,
|
||||||
|
torch::Tensor q_scale,
|
||||||
|
torch::Tensor q_scale_max,
|
||||||
|
torch::Tensor q_groups,
|
||||||
|
torch::Tensor q_group_map,
|
||||||
|
torch::Tensor gptq_qzeros,
|
||||||
|
torch::Tensor gptq_scales,
|
||||||
|
torch::Tensor gptq_g_idx,
|
||||||
|
torch::Tensor temp_dq
|
||||||
|
)
|
||||||
|
{
|
||||||
|
TORCH_CHECK_DTYPE(q_weight, kInt);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_group_map, kShort);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
|
||||||
|
|
||||||
|
TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
|
||||||
|
|
||||||
|
int device = q_weight.device().index();
|
||||||
|
int width = q_weight.size(1);
|
||||||
|
int groups;
|
||||||
|
int height;
|
||||||
|
|
||||||
|
if (!q_scale.device().is_meta())
|
||||||
|
{
|
||||||
|
TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
|
||||||
|
TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
|
||||||
|
groups = q_scale.size(0);
|
||||||
|
height = q_invperm.size(0);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
|
||||||
|
TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
|
||||||
|
groups = gptq_qzeros.size(0);
|
||||||
|
height = q_weight.size(0) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
|
||||||
|
|
||||||
|
QMatrix* m = new QMatrix
|
||||||
|
(
|
||||||
|
device,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
groups,
|
||||||
|
(uint32_t*) q_weight.data_ptr(),
|
||||||
|
q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
|
||||||
|
q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
|
||||||
|
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
|
||||||
|
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
|
||||||
|
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
|
||||||
|
q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(),
|
||||||
|
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
|
||||||
|
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
|
||||||
|
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
|
||||||
|
(half*) temp_dq.data_ptr()
|
||||||
|
);
|
||||||
|
|
||||||
|
if (m->failed) throw std::runtime_error("CUDA out of memory");
|
||||||
|
|
||||||
|
return reinterpret_cast<uintptr_t> (m);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gemm_half_q_half
|
||||||
|
(
|
||||||
|
torch::Tensor a,
|
||||||
|
uintptr_t b,
|
||||||
|
torch::Tensor c,
|
||||||
|
bool force_cuda
|
||||||
|
)
|
||||||
|
{
|
||||||
|
ofstream myfile;
|
||||||
|
myfile.open ("/tgi/server/exllamav2_kernels/log.txt");
|
||||||
|
myfile << "start gemm_half_q_half" << "\n";
|
||||||
|
|
||||||
|
QMatrix* qm = reinterpret_cast<QMatrix*> (b);
|
||||||
|
|
||||||
|
TORCH_CHECK_DTYPE(a, kHalf);
|
||||||
|
TORCH_CHECK_DTYPE(c, kHalf);
|
||||||
|
TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
|
||||||
|
TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
|
||||||
|
TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
|
||||||
|
|
||||||
|
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(a));
|
||||||
|
|
||||||
|
myfile << "call gemm_half_q_half_cuda" << "\n";
|
||||||
|
gemm_half_q_half_cuda
|
||||||
|
(
|
||||||
|
at::cuda::getCurrentCUDABlasHandle(),
|
||||||
|
(const half*) a.data_ptr(),
|
||||||
|
qm,
|
||||||
|
(half*) c.data_ptr(),
|
||||||
|
c.size(0), // m
|
||||||
|
c.size(1), // n
|
||||||
|
a.size(1), // k
|
||||||
|
true,
|
||||||
|
NULL,
|
||||||
|
force_cuda
|
||||||
|
);
|
||||||
|
myfile.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bindings
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||||
|
{
|
||||||
|
m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
|
||||||
|
m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
|
||||||
|
}
|
@@ -31,6 +31,7 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
 
+from loguru import logger
 
 @dataclass
 class FlashCausalLMBatch(Batch):

@@ -679,8 +680,12 @@ class FlashCausalLM(Model):
         return FlashCausalLMBatch
 
     def warmup(self, batch: FlashCausalLMBatch):
+        logger.info("in this warmup start")
+
         torch.cuda.empty_cache()
-        try:
+        logger.info("in this warmup after empty cache")
+
+        #try:
         cache_manager = set_cache_manager(
             batch.blocks,
             self.num_layers,

@@ -690,15 +695,20 @@ class FlashCausalLM(Model):
             self.dtype,
             self.device,
         )
+        logger.info("in this warmup after set_cache_manager")
         _, batch, _ = self.generate_token(batch)
-        except torch.cuda.OutOfMemoryError as e:
-            raise RuntimeError(
-                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
-                f"You need to decrease `--max-batch-prefill-tokens`"
-            ) from e
+        logger.info("in this warmup after generate_token")
+        # except torch.cuda.OutOfMemoryError as e:
+        #     raise RuntimeError(
+        #         f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+        #         f"You need to decrease `--max-batch-prefill-tokens`"
+        #     ) from e
+
 
         torch.cuda.synchronize(self.device)
 
+        logger.info("in this warmup after sync")
+
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()

@@ -818,11 +828,14 @@ class FlashCausalLM(Model):
         batch.block_tables_tensor = block_tables_tensor
         batch.slots = slots
 
-        try:
-            out = self.forward(batch)
-        except Exception as e:
-            del batch
-            raise e
+        logger.info("callign forward in generate_token")
+        out = self.forward(batch)
+        # try:
+        #     out = self.forward(batch)
+        # except Exception as e:
+        #     del batch
+        #     raise e
+        logger.info("finished forward in generate_token")
 
         if isinstance(out, tuple):
             out, speculative_logits = out
@@ -71,6 +71,7 @@ class Model(ABC):
         raise NotImplementedError
 
     def warmup(self, batch: B) -> Optional[int]:
+        logger.info("in this warmup model.py")
         self.generate_token(batch)
         return None
 
@@ -63,6 +63,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
     async def Warmup(self, request, context):
+        logger.info("IN WARMUP")
         if self.quantize == "gptq":
             try:
                 # When using GPTQ, Exllama kernels need some global kernels

@@ -78,6 +79,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             except ImportError:
                 pass
 
+        logger.info("after quantize == gptq")
         if (
             self.model.batch_type == IdeficsCausalLMBatch
         ):  # Hack, i would rather use kwargs in the `from_pb` call

@@ -92,8 +94,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = self.model.batch_type.from_pb(
                 request.batch, self.model.tokenizer, self.model.dtype, self.model.device
             )
+        logger.info("calling model.warmup")
         max_supported_total_tokens = self.model.warmup(batch)
 
+        logger.info("end warmup")
         return generate_pb2.WarmupResponse(
             max_supported_total_tokens=max_supported_total_tokens
         )
@@ -1,12 +1,10 @@
 # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
 
-from logging import getLogger
-
 import torch
 import torch.nn as nn
 import math
 
-logger = getLogger(__name__)
+from loguru import logger
 
 try:
     from exllamav2_kernels import make_q_matrix, gemm_half_q_half

@@ -23,6 +21,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
     output_shape = x.shape[:-1] + (q4_width,)
     x = x.view(-1, x.shape[-1])
     output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)
+    logger.info("calling gemm_half_q_half")
     gemm_half_q_half(x, q_handle, output, force_cuda)
     return output.view(output_shape)
 

@@ -192,6 +191,7 @@ class QuantLinear(nn.Module):
         self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
 
     def forward(self, x, force_cuda=False):
+        logger.info("calling ext_gemm_half_q_half")
         output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
 
         if self.bias is not None:
@@ -33,7 +33,7 @@ except Exception:
     major = 1
 
 HAS_EXLLAMA = False
-CAN_EXLLAMA = major >= 8
+CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 # if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
 #     V2 = False
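For context, these flags decide which exllama kernel generation the Python layer will attempt to load. A standalone sketch of that selection logic follows; the major and IS_ROCM_SYSTEM values are stand-ins here, since in TGI they come from the detected device rather than constants.

# Standalone sketch of how these flags gate kernel selection; `major` and
# IS_ROCM_SYSTEM are stand-ins (TGI derives them from the detected device).
import os

major = 9               # e.g. torch.cuda.get_device_capability()[0]
IS_ROCM_SYSTEM = False  # e.g. torch.version.hip is not None

HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"

if CAN_EXLLAMA and V2:
    print("would try the exllamav2 kernels")
elif CAN_EXLLAMA:
    print("would try the exllama (v1) kernels")
else:
    print("falls back to the plain GPTQ matmul path")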
|
Loading…
Reference in New Issue
Block a user