Mirror of https://github.com/huggingface/text-generation-inference.git
Add support for GPTQ Marlin kernels

GPTQ Marlin extends the Marlin kernels to support common GPTQ configurations:

- bits: 4 or 8
- groupsize: -1, 32, 64, or 128
- desc_act: true/false

Using the GPTQ Marlin kernels requires repacking the parameters in the Marlin quantizer format.

The kernels were contributed by Neural Magic to vLLM. We vendor them here for convenience.
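For illustration, a minimal sketch of a validity check over the configurations listed above (the helper name is hypothetical and not part of this change; desc_act needs no check since both values are supported):

constexpr bool is_supported_gptq_marlin_config(int bits, int groupsize) {
  // Supported bit widths: 4 or 8.
  bool bits_ok = bits == 4 || bits == 8;
  // Supported group sizes: -1 (a single group spanning the whole column),
  // 32, 64, or 128.
  bool group_ok = groupsize == -1 || groupsize == 32 || groupsize == 64 ||
                  groupsize == 128;
  return bits_ok && group_ok;
}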
#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace gptq_marlin {

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per scheduler allows some more latency hiding. At the same
// time, we want relatively few warps to have many registers per warp and
// small tiles.
static constexpr int default_threads = 256;

static constexpr int pipe_stages =
    4;  // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

// Fixed-size vector of `n` elements of type `T`, used for vectorized loads
// and stores.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

using I4 = Vec<int, 4>;  // 16 bytes: one 128-bit memory transaction

// Ceiling division: number of size-`b` tiles needed to cover `a` elements.
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
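
// Illustrative sketch (added for exposition; not part of the vendored
// kernels): because `I4` is 16 bytes, assigning one `I4` moves four ints in
// a single vectorized load/store. Both pointers are assumed to be 16-byte
// aligned; the function name is hypothetical.
__device__ inline void copy_i4_example(const int* in, int* out, int idx) {
  const I4* in4 = reinterpret_cast<const I4*>(in);
  I4* out4 = reinterpret_cast<I4*>(out);
  out4[idx] = in4[idx];  // four ints in one 128-bit transaction
}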

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  // No support for async
#else

// Asynchronously copy 16 bytes from global to shared memory (cp.async.cg),
// predicated: the copy is only issued when `pred` is true.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronously copy 16 bytes from global to shared memory.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

// Commit all previously issued (and uncommitted) cp.async operations into
// one group that can be waited on.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` of the committed cp.async groups are still pending.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
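
// Illustrative sketch (added for exposition; not part of the vendored
// kernels): the software-pipelining pattern these helpers enable. Each
// thread streams its own 16-byte chunks through a circular shared-memory
// buffer, keeping up to `pipe_stages` copies in flight while completed
// chunks are consumed. All names (`smem_buf`, `gmem`, `iters`) are
// hypothetical; assumes `iters >= pipe_stages` and 16-byte aligned pointers.
__device__ inline void stream_chunks_example(int4* smem_buf, const int4* gmem,
                                             int iters) {
  const int tid = threadIdx.x;
  // Fill the pipeline: prefetch the first `pipe_stages` chunks, committing
  // one cp.async group per stage.
  for (int stage = 0; stage < pipe_stages; stage++) {
    cp_async4(&smem_buf[stage * blockDim.x + tid],
              &gmem[stage * blockDim.x + tid]);
    cp_async_fence();
  }
  for (int i = 0; i < iters; i++) {
    // Chunk i's group is complete once at most `pipe_stages - 1` newer
    // groups are still pending.
    cp_async_wait<pipe_stages - 1>();
    __syncthreads();  // needed only if other threads read this thread's chunk
    int slot = i % pipe_stages;
    // ... consume smem_buf[slot * blockDim.x + tid] here ...
    // Refill the freed slot with the chunk `pipe_stages` steps ahead, and
    // commit exactly one (possibly empty) group per iteration so the
    // wait_group bookkeeping stays uniform.
    if (i + pipe_stages < iters) {
      cp_async4(&smem_buf[slot * blockDim.x + tid],
                &gmem[(i + pipe_stages) * blockDim.x + tid]);
    }
    cp_async_fence();
  }
}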

#endif

}  // namespace gptq_marlin