hotfix: increase precision of GPTQ/AWQ-Marlin

Sync with an upstream change that switches the 'global_reduce' step
from FP16 to FP32 accumulation, improving precision. This resolves
some reported generation quality issues.

Upstream issue/PR:

https://github.com/vllm-project/vllm/pull/6795
Author: Daniël de Kok
Date:   2024-07-29 08:40:17 +00:00
parent 4b49c50f4c
commit 4f69d04c3a
4 changed files with 491 additions and 386 deletions
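
Why the FP32 global reduce matters: when several threadblocks contribute partial sums to the same output tile, the old global_reduce path accumulated in FP32 but wrote the running total back to the FP16 output buffer after every step, so rounding error grows with the number of partial results; the new path keeps the running total in an FP32 scratch buffer (C_tmp) and converts to FP16 only once. A rough, self-contained PyTorch sketch of the effect (illustrative only, not the kernel code; shapes, magnitudes, and the number of partial results are made up):

import torch

torch.manual_seed(0)
# Pretend 8 threadblocks each produced a partial result for one output tile.
partials = [torch.randn(256, dtype=torch.float32) * 100 for _ in range(8)]
ref = sum(partials)  # fp32 reference

# Old behaviour: add in fp32, but round the running total to fp16 every step.
acc16 = torch.zeros(256, dtype=torch.float16)
for p in partials:
    acc16 = (acc16.float() + p).half()

# New behaviour: keep the running total in fp32, round to fp16 once at the end.
acc32 = torch.zeros(256, dtype=torch.float32)
for p in partials:
    acc32 += p
acc32 = acc32.half()

print((acc16.float() - ref).abs().max())  # error accumulates across steps
print((acc32.float() - ref).abs().max())  # only the final rounding error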

View File

@@ -1,5 +1,11 @@
 import torch

+def awq_marlin_repack(
+    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    """Repack AWQ parameters for GPTQ-Marlin."""
+    ...
+
 def gptq_marlin_gemm(
     a: torch.Tensor,
     b_q_weight: torch.Tensor,
@@ -12,6 +18,8 @@ def gptq_marlin_gemm(
     size_n: int,
     size_k: int,
     is_k_full: bool,
+    has_zp: bool,
+    use_fp32_reduce: bool,
 ) -> torch.Tensor:
     """
     Matrix multiplication using Marlin kernels. This is an extension of

View File

@@ -14,7 +14,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                torch::Tensor &g_idx, torch::Tensor &perm,
                                torch::Tensor &workspace, int64_t num_bits,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp);
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce);

 torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                   torch::Tensor &b_meta,

View File

@@ -27,10 +27,7 @@
                   std::is_same<scalar_t, nv_bfloat16>::value, \
                   "only float16 and bfloat16 is supported");

-template <typename T>
-inline std::string str(T x) {
-  return std::to_string(x);
-}
+template <typename T> inline std::string str(T x) { return std::to_string(x); }

 namespace marlin {
@@ -55,10 +52,11 @@ template <typename scalar_t,  // compute dtype, half or nv_float16
           const int group_blocks = -1  // number of consecutive 16x16 blocks
                                        // with a separate quantization scale
           >
-__global__ void Marlin(
-    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
+__global__ void
+Marlin(const int4 *__restrict__ A,  // fp16 input matrix of shape mxk
        const int4 *__restrict__ B,  // 4bit quantized weight matrix of shape kxn
       int4 *__restrict__ C,         // fp16 output buffer of shape mxn
+      int4 *__restrict__ C_tmp,     // fp32 tmp output buffer (for reduce)
       const int4 *__restrict__ scales_ptr,  // fp16 quantization scales of shape
                                             // (k/groupsize)xn
       const int *__restrict__ g_idx,        // int32 group indices of shape k
@@ -66,7 +64,8 @@ __global__ void Marlin(
       int prob_m,            // batch dimension m
       int prob_n,            // output dimension n
       int prob_k,            // reduction dimension k
-      int* locks             // extra global storage for barrier synchronization
+      int *locks,            // extra global storage for barrier synchronization
+      bool use_fp32_reduce   // whether to use fp32 global reduce
 ) {}

 }  // namespace gptq_marlin
@@ -76,7 +75,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor &g_idx, torch::Tensor &perm,
                                torch::Tensor &workspace, int64_t num_bits,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full) {
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce) {
   TORCH_CHECK_NOT_IMPLEMENTED(false,
                               "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
   return torch::empty({1, 1});
@@ -94,19 +94,17 @@ __device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
   const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
   float *c = reinterpret_cast<float *>(&frag_c);
   if constexpr (std::is_same<scalar_t, half>::value) {
-    asm volatile(
-        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
         "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
         : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
-        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
-          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
+          "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
   } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
-    asm volatile(
-        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
         "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
         : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
-        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
-          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
+          "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
   } else {
     STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
   }
@@ -127,8 +125,7 @@ __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
 // Lookup-table based 3-input logical operation; explicitly used for
 // dequantization as the compiler does not seem to automatically recognize it in
 // all cases.
-template <int lut>
-__device__ inline int lop3(int a, int b, int c) {
+template <int lut> __device__ inline int lop3(int a, int b, int c) {
   int res;
   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                : "=r"(res)
@@ -273,8 +270,8 @@ __device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
 }

 template <>
-__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
-    int q) {
+__device__ inline typename ScalarType<half>::FragB
+dequant_4bit_zp<half>(int q) {
   const int LO = 0x000f000f;
   const int HI = 0x00f000f0;
   const int EX = 0x64006400;
@@ -325,8 +322,8 @@ __device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
 }

 template <>
-__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
-    int q) {
+__device__ inline typename ScalarType<half>::FragB
+dequant_8bit_zp<half>(int q) {
   static constexpr uint32_t mask_for_elt_01 = 0x5250;
   static constexpr uint32_t mask_for_elt_23 = 0x5351;
   static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
@@ -481,7 +478,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
   int offset = row * row_stride;

-  half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
+  half const *a_row_half =
+      reinterpret_cast<half const *>(a_int4_ptr + offset);
   half *out_half = reinterpret_cast<half *>(out_int4_ptr + offset);

   int base_k = 0;
@@ -528,10 +526,11 @@ template <typename scalar_t,  // compute dtype, half or nv_float16
           const int group_blocks = -1  // number of consecutive 16x16 blocks
                                        // with a separate quantization scale
           >
-__global__ void Marlin(
-    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
+__global__ void
+Marlin(const int4 *__restrict__ A,  // fp16 input matrix of shape mxk
       const int4 *__restrict__ B,  // 4bit quantized weight matrix of shape kxn
       int4 *__restrict__ C,        // fp16 output buffer of shape mxn
+      int4 *__restrict__ C_tmp,    // fp32 tmp output buffer (for reduce)
       const int4 *__restrict__ scales_ptr,  // fp16 quantization scales of shape
                                             // (k/groupsize)xn
       const int4 *__restrict__ zp_ptr,      // 4bit packed zero-points of shape
@@ -541,7 +540,8 @@ __global__ void Marlin(
       int prob_m,            // batch dimension m
       int prob_n,            // output dimension n
       int prob_k,            // reduction dimension k
-      int* locks             // extra global storage for barrier synchronization
+      int *locks,            // extra global storage for barrier synchronization
+      bool use_fp32_reduce   // whether to use fp32 global reduce
 ) {
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
   // same size, which might involve multiple column "slices" (of width 16 *
@@ -595,6 +595,8 @@ __global__ void Marlin(
   int slice_idx;  // index of threadblock in current slice; numbered bottom to
                   // top

+  int par_id = 0;
+
   // We can easily implement parallel problem execution by just remapping
   // indices and advancing global pointers
   if (slice_col_par >= n_tiles) {
@@ -602,6 +604,7 @@ __global__ void Marlin(
     C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
     locks += (slice_col_par / n_tiles) * n_tiles;
     slice_col = slice_col_par % n_tiles;
+    par_id = slice_col_par / n_tiles;
   }

   // Compute all information about the current slice which is required for
@@ -609,22 +612,27 @@ __global__ void Marlin(
   auto init_slice = [&]() {
     slice_iters =
         iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
-    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
-    if (slice_iters == 0) return;
-    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
+    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
+      slice_iters = 0;
+    if (slice_iters == 0)
+      return;
+    if (slice_row + slice_iters > k_tiles)
+      slice_iters = k_tiles - slice_row;
     slice_count = 1;
     slice_idx = 0;
     int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
     if (col_first <= k_tiles * (slice_col_par + 1)) {
       int col_off = col_first - k_tiles * slice_col_par;
       slice_count = div_ceil(k_tiles - col_off, iters);
-      if (col_off > 0) slice_count++;
+      if (col_off > 0)
+        slice_count++;
       int delta_first = iters * blockIdx.x - col_first;
       if (delta_first < 0 || (col_off == 0 && delta_first == 0))
         slice_idx = slice_count - 1;
       else {
         slice_idx = slice_count - 1 - delta_first / iters;
-        if (col_off > 0) slice_idx--;
+        if (col_off > 0)
+          slice_idx--;
       }
     }
     if (slice_col == n_tiles) {
@@ -632,6 +640,7 @@ __global__ void Marlin(
       C += 16 * thread_m_blocks * prob_n / 8;
       locks += n_tiles;
       slice_col = 0;
+      par_id++;
     }
   };
   init_slice();
@@ -1287,8 +1296,8 @@ __global__ void Marlin(
           int red_sh_wr =
               red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
           if (i < red_off) {
-            float* c_rd =
-                reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
+            float *c_rd = reinterpret_cast<float *>(
+                &sh[red_sh_delta * j + red_sh_rd]);
             float *c_wr = reinterpret_cast<float *>(&sh[red_sh_wr]);
 #pragma unroll
             for (int k = 0; k < 4; k++)
@@ -1321,7 +1330,7 @@ __global__ void Marlin(
   // finally have to globally reduce over the results. As the striped
   // partitioning minimizes the number of such reductions and our outputs are
   // usually rather small, we perform this reduction serially in L2 cache.
-  auto global_reduce = [&](bool first = false, bool last = false) {
+  auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
     // We are very careful here to reduce directly in the output buffer to
     // maximize L2 cache utilization in this step. To do this, we write out
     // results in FP16 (but still reduce with FP32 compute).
@@ -1344,11 +1353,11 @@ __global__ void Marlin(
       // though these fetches are not actually asynchronous.
 #pragma unroll
       for (int i = 0; i < thread_m_blocks * 4; i++) {
-        cp_async4_pred(
-            &sh[c_sh_wr + c_sh_wr_delta * i],
+        cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
             &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
                c_gl_wr_delta_i * (i % 2)],
-            i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
+            i < (thread_m_blocks - 1) * 4 ||
+                8 * (i / 2) + row < prob_m);
       }
       cp_async_fence();
       cp_async_wait<0>();
@@ -1382,6 +1391,53 @@ __global__ void Marlin(
     }
   };

+  // Globally reduce over threadblocks that compute the same column block.
+  // We use a tmp C buffer to reduce in full fp32 precision.
+  auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
+    constexpr int tb_m = thread_m_blocks * 16;
+    constexpr int tb_n = thread_n_blocks * 16;
+
+    constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
+
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    bool is_th_active = threadIdx.x < active_threads;
+
+    int par_offset = c_size * n_tiles * par_id;
+    int slice_offset = c_size * slice_col;
+
+    constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
+    constexpr int th_size = num_floats * sizeof(float) / 16;
+
+    int c_cur_offset = par_offset + slice_offset;
+
+    if (!is_th_active) {
+      return;
+    }
+
+    if (!first) {
+      float *frag_c_ptr = reinterpret_cast<float *>(&frag_c);
+#pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        sh[threadIdx.x] =
+            C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
+
+        float *sh_c_ptr = reinterpret_cast<float *>(&sh[threadIdx.x]);
+#pragma unroll
+        for (int f = 0; f < 4; f++) {
+          frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
+        }
+      }
+    }
+
+    if (!last) {
+      int4 *frag_c_ptr = reinterpret_cast<int4 *>(&frag_c);
+#pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
+      }
+    }
+  };
+
   // Write out the reduce final result in the correct layout. We only actually
   // reshuffle matrix fragments in this step, the reduction above is performed
   // in fragment layout.
@@ -1606,7 +1662,11 @@ __global__ void Marlin(
       if (slice_count > 1) {  // only globally reduce if there is more than one
                               // block in a slice
         barrier_acquire(&locks[slice_col], slice_idx);
-        global_reduce(slice_idx == 0, last);
+        if (use_fp32_reduce) {
+          global_reduce_fp32(slice_idx == 0, last);
+        } else {
+          global_reduce_fp16(slice_idx == 0, last);
+        }
         barrier_release(&locks[slice_col], last);
       }
       if (last)  // only the last block in a slice actually writes the result
@@ -1623,7 +1683,8 @@ __global__ void Marlin(
         B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
       if (slice_col == 0) {
 #pragma unroll
-        for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
+        for (int i = 0; i < b_sh_wr_iters; i++)
+          B_ptr[i] -= b_gl_stride;
       }

       // Update slice k/n for scales loading
@@ -1644,9 +1705,8 @@ __global__ void Marlin(
   }
 }

-#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
-                  THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
-                  NUM_THREADS) \
+#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
+                  HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
   else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
            thread_n_blocks == THREAD_N_BLOCKS && \
            thread_k_blocks == THREAD_K_BLOCKS && \
@@ -1657,12 +1717,11 @@ __global__ void Marlin(
                  THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
                  HAS_ZP, GROUP_BLOCKS>, \
              cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
-    Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
-           THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
-           HAS_ZP, GROUP_BLOCKS> \
+    Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
+           THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS> \
         <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
-            A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \
-            prob_m, prob_n, prob_k, locks); \
+            A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
+            num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
   }

 typedef struct {
@@ -1801,6 +1860,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
   return true;
 }

+int determine_reduce_max_m(int prob_m, int max_par) {
+  constexpr int tile_m_size = 16;
+
+  if (prob_m <= tile_m_size) {
+    return tile_m_size;
+
+  } else if (prob_m <= tile_m_size * 2) {
+    return tile_m_size * 2;
+
+  } else if (prob_m <= tile_m_size * 3) {
+    return tile_m_size * 3;
+
+  } else if (prob_m <= tile_m_size * 4) {
+    return tile_m_size * 4;
+
+  } else {
+    int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
+    return tile_m_size * 4 * cur_par;
+  }
+}
+
 exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                       int num_bits, int group_size,
                                       bool has_act_order, bool is_k_full,
@@ -1880,13 +1960,13 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)

 template <typename scalar_t>
-void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
-                     void* g_idx, void* perm, void* a_tmp, int prob_m,
-                     int prob_n, int prob_k, void* workspace, int num_bits,
-                     bool has_act_order, bool is_k_full, bool has_zp,
-                     int num_groups, int group_size, int dev,
+void marlin_mm_f16i4(const void *A, const void *B, void *C, void *C_tmp,
+                     void *s, void *zp, void *g_idx, void *perm, void *a_tmp,
+                     int prob_m, int prob_n, int prob_k, void *workspace,
+                     int num_bits, bool has_act_order, bool is_k_full,
+                     bool has_zp, int num_groups, int group_size, int dev,
                      cudaStream_t stream, int thread_k, int thread_n, int sms,
-                     int max_par) {
+                     int max_par, bool use_fp32_reduce) {
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
   TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@@ -1970,6 +2050,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
   const int4 *A_ptr = (const int4 *)A;
   const int4 *B_ptr = (const int4 *)B;
   int4 *C_ptr = (int4 *)C;
+  int4 *C_tmp_ptr = (int4 *)C_tmp;
   const int4 *s_ptr = (const int4 *)s;
   const int4 *zp_ptr = (const int4 *)zp;
   const int *g_idx_ptr = (const int *)g_idx;
@@ -2002,7 +2083,8 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
       // Note that parallel > 1 currently only works for inputs without any
       // padding
       par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
-      if (par > max_par) par = max_par;
+      if (par > max_par)
+        par = max_par;
       prob_m = (16 * exec_cfg.max_m_blocks) * par;
       i += exec_cfg.max_m_blocks * (par - 1);
       thread_m_blocks = exec_cfg.max_m_blocks;
@@ -2049,7 +2131,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor &g_idx, torch::Tensor &perm,
                                torch::Tensor &workspace, int64_t num_bits,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp) {
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce) {
   // Verify num_bits
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
@@ -2099,6 +2182,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   torch::Tensor c = torch::empty({size_m, size_n}, options);
   torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);

+  // Alloc C tmp buffer that is going to be used for the global reduce
+  int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
+  int reduce_n = size_n;
+  auto options_fp32 =
+      torch::TensorOptions().dtype(at::kFloat).device(a.device());
+  if (!use_fp32_reduce) {
+    reduce_max_m = 0;
+    reduce_n = 0;
+  }
+  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+
   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
   // auto -1)
   int thread_k = -1;
@@ -2171,20 +2265,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   if (a.scalar_type() == at::ScalarType::Half) {
     marlin::marlin_mm_f16i4<half>(
         a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
-        b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
-        perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
+        c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
+        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
+        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
-        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
-        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
-        a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
+        c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
+        b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
+        perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }

View File

@@ -223,6 +223,7 @@ class GPTQMarlinLinear(nn.Module):
             A_flat.shape[1],
             self.is_full_k,
             self.qzeros.numel() > 0,
+            True,
         )
         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))