Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-21 14:52:20 +00:00
Add support for GPTQ Marlin kernels

GPTQ Marlin extends the Marlin kernels to support common GPTQ configurations:

- bits: 4 or 8
- groupsize: -1, 32, 64, or 128
- desc_act: true/false

Using the GPTQ Marlin kernels requires repacking the parameters in the Marlin quantizer format.

The kernels were contributed by Neural Magic to VLLM. We vendor them here for convenience.
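As a rough illustration of the configuration space listed above, a quantizer could gate the Marlin path with a check like the following. This is a minimal sketch, not code from this repository; the function name can_use_gptq_marlin is hypothetical.

def can_use_gptq_marlin(bits: int, groupsize: int, desc_act: bool) -> bool:
    # Mirrors the supported configurations from the commit message above.
    # desc_act may be either true or false, so it does not restrict
    # eligibility on its own.
    return bits in (4, 8) and groupsize in (-1, 32, 64, 128)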
24 lines | 931 B | C++
#pragma once

#include <torch/library.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else

torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                               torch::Tensor &b_scales, torch::Tensor &g_idx,
                               torch::Tensor &perm, torch::Tensor &workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k, bool is_k_full);

torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                          torch::Tensor &b_scales, torch::Tensor &workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

#endif
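The sketch below shows how these entry points might be driven from Python once the extension is compiled: repack the GPTQ weights into the Marlin layout first, then run the matmul. Only the function signatures come from the header above; the module name marlin_kernels, the tensor shapes, and the workspace sizing are assumptions made for illustration.

import torch

import marlin_kernels  # assumed name of the compiled extension module

size_m, size_n, size_k = 16, 4096, 4096
num_bits, groupsize = 4, 128
pack_factor = 32 // num_bits  # GPTQ packs this many weights per int32

# GPTQ-format inputs (shapes assumed): packed weights and per-group scales.
b_q_weight = torch.zeros((size_k // pack_factor, size_n),
                         dtype=torch.int32, device="cuda")
b_scales = torch.ones((size_k // groupsize, size_n),
                      dtype=torch.float16, device="cuda")
g_idx = torch.empty((0,), dtype=torch.int32, device="cuda")  # desc_act=False
perm = torch.empty((0,), dtype=torch.int32, device="cuda")   # no act reordering

# Step 1: repack the GPTQ weights into the Marlin layout (required before gemm).
marlin_q_weight = marlin_kernels.gptq_marlin_repack(
    b_q_weight, perm, size_k, size_n, num_bits)

# Step 2: run the Marlin matmul. The workspace size used here is an assumption;
# the kernel defines its own minimum requirement.
a = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
workspace = torch.zeros(size_n // 64 * 16, dtype=torch.int32, device="cuda")
out = marlin_kernels.gptq_marlin_gemm(
    a, marlin_q_weight, b_scales, g_idx, perm, workspace,
    num_bits, size_m, size_n, size_k, True)  # is_k_full=True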