mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 04:22:08 +00:00
* Add support for repacking AWQ weights for GPTQ-Marlin

  So far we couldn't support AWQ because virtually all AWQ models use asymmetric quantization, which GPTQ-Marlin did not support. GPTQ-Marlin has recently added support for AWQ repacking and AWQ asymmetric quantization (zero_point=True). This change updates all GPTQ-Marlin kernels from upstream and wires up AWQ support. For now, enabling AWQ using Marlin requires running TGI with `--quantize gptq`.

* Enable Marlin for supported AWQ configurations by default

  This makes the AWQ -> GPTQ repack test redundant, since we are now testing this with the regular AWQ test.
40 lines
1.8 KiB
C++
40 lines
1.8 KiB
C++
#pragma once
|
|
|
|
#include <torch/library.h>
|
|
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
|
// No support for async
|
|
#else
|
|
|
|
torch::Tensor awq_marlin_repack(torch::Tensor &b_q_weight, int64_t size_k,
|
|
int64_t size_n, int64_t num_bits);
|
|
|
|
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
|
torch::Tensor &b_scales, torch::Tensor &b_zeros,
|
|
torch::Tensor &g_idx, torch::Tensor &perm,
|
|
torch::Tensor &workspace, int64_t num_bits,
|
|
int64_t size_m, int64_t size_n, int64_t size_k,
|
|
bool is_k_full, bool has_zp);
|
|
|
|
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
|
torch::Tensor &b_meta,
|
|
torch::Tensor &b_scales,
|
|
torch::Tensor &workspace, int64_t num_bits,
|
|
int64_t size_m, int64_t size_n,
|
|
int64_t size_k);
|
|
|
|
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
|
|
int64_t size_k, int64_t size_n,
|
|
int64_t num_bits);
|
|
|
|
torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
|
torch::Tensor &b_scales, torch::Tensor &workspace,
|
|
int64_t size_m, int64_t size_n, int64_t size_k);
|
|
|
|
torch::Tensor fp8_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
|
torch::Tensor &b_scales, torch::Tensor &workspace,
|
|
int64_t num_bits, int64_t size_m, int64_t size_n,
|
|
int64_t size_k);
|
|
|
|
#endif
|