From 4f69d04c3aaccf2ebfecff98686f4a2c546bda45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 29 Jul 2024 08:40:17 +0000 Subject: [PATCH] hotfix: increase precision of GPTQ/AWQ-Marlin Sync with upstream change that improves the precision of the 'global_reduce' algorithm from FP16 to FP32. This solves some reported generation quality issues. Upstream issue/PR: https://github.com/vllm-project/vllm/pull/6795 --- server/marlin/marlin_kernels/__init__.pyi | 8 + server/marlin/marlin_kernels/ext.hh | 3 +- server/marlin/marlin_kernels/gptq_marlin.cu | 865 ++++++++++-------- .../layers/marlin/gptq.py | 1 + 4 files changed, 491 insertions(+), 386 deletions(-) diff --git a/server/marlin/marlin_kernels/__init__.pyi b/server/marlin/marlin_kernels/__init__.pyi index 53464719..4d8c34f0 100644 --- a/server/marlin/marlin_kernels/__init__.pyi +++ b/server/marlin/marlin_kernels/__init__.pyi @@ -1,5 +1,11 @@ import torch +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + """Repack AWQ parameters for GPTQ-Marlin.""" + ... + def gptq_marlin_gemm( a: torch.Tensor, b_q_weight: torch.Tensor, @@ -12,6 +18,8 @@ def gptq_marlin_gemm( size_n: int, size_k: int, is_k_full: bool, + has_zp: bool, + use_fp32_reduce: bool, ) -> torch.Tensor: """ Matrix multiplication using Marlin kernels. This is an extension of diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh index ec76b5f4..cf7cb412 100644 --- a/server/marlin/marlin_kernels/ext.hh +++ b/server/marlin/marlin_kernels/ext.hh @@ -14,7 +14,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &g_idx, torch::Tensor &perm, torch::Tensor &workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp); + bool is_k_full, bool has_zp, + bool use_fp32_reduce); torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_meta, diff --git a/server/marlin/marlin_kernels/gptq_marlin.cu b/server/marlin/marlin_kernels/gptq_marlin.cu index 122c5c16..bc632475 100644 --- a/server/marlin/marlin_kernels/gptq_marlin.cu +++ b/server/marlin/marlin_kernels/gptq_marlin.cu @@ -22,61 +22,61 @@ #include "marlin.cuh" #include "marlin_dtypes.cuh" -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ - static_assert(std::is_same::value || \ - std::is_same::value, \ +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ "only float16 and bfloat16 is supported"); -template -inline std::string str(T x) { - return std::to_string(x); -} +template inline std::string str(T x) { return std::to_string(x); } namespace marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, +__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization +__global__ void +Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce ) {} -} // namespace gptq_marlin +} // namespace gptq_marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, +torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &b_zeros, + torch::Tensor &g_idx, torch::Tensor &perm, + torch::Tensor &workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full) { + bool is_k_full, bool has_zp, + bool use_fp32_reduce) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -87,26 +87,24 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); +__device__ inline void mma(const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + typename ScalarType::FragC &frag_c) { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + float *c = reinterpret_cast(&frag_c); if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } @@ -115,9 +113,9 @@ __device__ inline void mma(const typename ScalarType::FragA& a_frag, // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. template -__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, - const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(typename ScalarType::FragA &frag_a, + const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) @@ -127,8 +125,7 @@ __device__ inline void ldsm4(typename ScalarType::FragA& frag_a, // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. -template -__device__ inline int lop3(int a, int b, int c) { +template __device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -173,11 +170,11 @@ __device__ inline typename ScalarType::FragB dequant_4bit(int q) { const int MUL = 0x2c002c00; const int ADD = 0xd480d480; typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } @@ -197,12 +194,12 @@ dequant_4bit(int q) { static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t ADD = 0xC308C308; - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } @@ -229,10 +226,10 @@ __device__ inline typename ScalarType::FragB dequant_8bit(int q) { static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); return frag_b; } @@ -242,8 +239,8 @@ dequant_8bit(int q) { typename ScalarType::FragB frag_b; float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); + uint32_t *fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); static constexpr uint32_t fp32_base = 0x4B000000; fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); @@ -256,7 +253,7 @@ dequant_8bit(int q) { fp32_intermediates[2] -= 8388736.f; fp32_intermediates[3] -= 8388736.f; - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + uint32_t *bf16_result_ptr = reinterpret_cast(&frag_b); bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], @@ -273,8 +270,8 @@ __device__ inline typename ScalarType::FragB dequant_4bit_zp(int q) { } template <> -__device__ inline typename ScalarType::FragB dequant_4bit_zp( - int q) { +__device__ inline typename ScalarType::FragB +dequant_4bit_zp(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -286,11 +283,11 @@ __device__ inline typename ScalarType::FragB dequant_4bit_zp( const int MUL = 0x2c002c00; const int ADD = 0xd400d400; typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } @@ -310,12 +307,12 @@ dequant_4bit_zp(int q) { static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t ADD = 0xC300C300; - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } @@ -325,8 +322,8 @@ __device__ inline typename ScalarType::FragB dequant_8bit_zp(int q) { } template <> -__device__ inline typename ScalarType::FragB dequant_8bit_zp( - int q) { +__device__ inline typename ScalarType::FragB +dequant_8bit_zp(int q) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; @@ -337,10 +334,10 @@ __device__ inline typename ScalarType::FragB dequant_8bit_zp( static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); return frag_b; } @@ -350,8 +347,8 @@ dequant_8bit_zp(int q) { typename ScalarType::FragB frag_b; float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); + uint32_t *fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); static constexpr uint32_t fp32_base = 0x4B000000; fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); @@ -364,7 +361,7 @@ dequant_8bit_zp(int q) { fp32_intermediates[2] -= 8388608.f; fp32_intermediates[3] -= 8388608.f; - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + uint32_t *bf16_result_ptr = reinterpret_cast(&frag_b); bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], @@ -376,43 +373,43 @@ dequant_8bit_zp(int q) { // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, +__device__ inline void scale(typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } template -__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, - typename ScalarType::scalar_t2& frag_zp, +__device__ inline void sub_zp(typename ScalarType::FragB &frag_b, + typename ScalarType::scalar_t2 &frag_zp, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 zp = - ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); frag_b[0] = __hsub2(frag_b[0], zp); frag_b[1] = __hsub2(frag_b[1], zp); } // Same as above, but for act_order (each K is multiplied individually) template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, +__device__ inline void scale4(typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s_1, + typename ScalarType::FragS &frag_s_2, + typename ScalarType::FragS &frag_s_3, + typename ScalarType::FragS &frag_s_4, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s_val_1_2; - s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; scalar_t2 s_val_3_4; - s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; frag_b[0] = __hmul2(frag_b[0], s_val_1_2); frag_b[1] = __hmul2(frag_b[1], s_val_3_4); @@ -420,15 +417,15 @@ __device__ inline void scale4(typename ScalarType::FragB& frag_b, // Given 2 floats multiply by 2 scales (halves) template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { - scalar_t* s_ptr = reinterpret_cast(&s); +__device__ inline void scale_float(float *c, + typename ScalarType::FragS &s) { + scalar_t *s_ptr = reinterpret_cast(&s); c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { +__device__ inline void barrier_acquire(int *lock, int count) { if (threadIdx.x == 0) { int state = -1; do @@ -443,7 +440,7 @@ __device__ inline void barrier_acquire(int* lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { +__device__ inline void barrier_release(int *lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -462,9 +459,9 @@ __device__ inline void barrier_release(int* lock, bool reset = false) { // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, +__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) { int start_row = block_rows * blockIdx.x; int finish_row = start_row + block_rows; @@ -481,8 +478,9 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int offset = row * row_stride; - half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); - half* out_half = reinterpret_cast(out_int4_ptr + offset); + half const *a_row_half = + reinterpret_cast(a_int4_ptr + offset); + half *out_half = reinterpret_cast(out_int4_ptr + offset); int base_k = 0; @@ -513,35 +511,37 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, } } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization +__global__ void +Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4 *__restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -589,11 +589,13 @@ __global__ void Marlin( int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice + int slice_iters; // number of threadblock tiles in the current slice int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers @@ -602,6 +604,7 @@ __global__ void Marlin( C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; locks += (slice_col_par / n_tiles) * n_tiles; slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; } // Compute all information about the current slice which is required for @@ -609,22 +612,27 @@ __global__ void Marlin( auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; + if (col_off > 0) + slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; + if (col_off > 0) + slice_idx--; } } if (slice_col == n_tiles) { @@ -632,6 +640,7 @@ __global__ void Marlin( C += 16 * thread_m_blocks * prob_n / 8; locks += n_tiles; slice_col = 0; + par_id++; } }; init_slice(); @@ -776,7 +785,7 @@ __global__ void Marlin( // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; @@ -794,13 +803,13 @@ __global__ void Marlin( // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll +#pragma unroll for (int j = 0; j < thread_m_blocks; j++) a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -810,33 +819,33 @@ __global__ void Marlin( // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_zp = sh_g_idx + (stages * g_idx_stage); - int4* sh_s = sh_zp + (stages * zp_sh_stage); + int4 *sh_a = sh; + int4 *sh_b = sh_a + (stages * a_sh_stage); + int4 *sh_g_idx = sh_b + (stages * b_sh_stage); + int4 *sh_zp = sh_g_idx + (stages * g_idx_stage); + int4 *sh_s = sh_zp + (stages * zp_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - int frag_qzp[2][num_ints_per_thread]; // Zero-points - FragZP frag_zp; // Zero-points in fp16 + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 // Zero accumulators. auto zero_accums = [&]() { - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; int sh_first_group_id = -1; @@ -880,18 +889,18 @@ __global__ void Marlin( // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll +#pragma unroll for (int j = 0; j < b_thread_vecs; j++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } @@ -904,10 +913,10 @@ __global__ void Marlin( int full_pipe = a_off; int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); + int4 const *cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); if (threadIdx.x < g_idx_stage) { cp_async4_pred(&sh_g_idx_stage[threadIdx.x], @@ -916,7 +925,7 @@ __global__ void Marlin( } } else { if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; if constexpr (group_blocks >= thread_k_blocks) { // Only fetch scales if this tile starts a new group @@ -938,7 +947,7 @@ __global__ void Marlin( } if constexpr (has_zp && group_blocks != -1) { - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; if constexpr (group_blocks >= thread_k_blocks) { // Only fetch zero-points if this tile starts a new group @@ -984,16 +993,16 @@ __global__ void Marlin( // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll +#pragma unroll for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( + frag_b_quant[k % 2][i] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; @@ -1008,8 +1017,8 @@ __global__ void Marlin( return; } - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); int group_id_1 = sh_g_idx_int_ptr[0]; int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; @@ -1025,10 +1034,10 @@ __global__ void Marlin( // No act-order case if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = + int4 *sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } else { int warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; @@ -1041,9 +1050,9 @@ __global__ void Marlin( int k_blocks = cur_k / 16; int cur_group_id = k_blocks / group_blocks; - int4* sh_s_stage = sh_s + s_sh_stage * pipe; + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } } @@ -1070,7 +1079,7 @@ __global__ void Marlin( // thread-id) int warp_id = threadIdx.x / 32; int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N int warp_row = warp_id / n_warps; int warp_col = warp_id % n_warps; @@ -1078,7 +1087,7 @@ __global__ void Marlin( cur_k += warp_row * 16; int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix int s_col_shift = /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + @@ -1086,35 +1095,35 @@ __global__ void Marlin( if (is_same_group[pipe]) { if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); } for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); } return; } - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread + 9}; // Tensor core offsets per thread - #pragma unroll +#pragma unroll for (int i = 0; i < 4; i++) { int actual_k = cur_k + k_frag_offsets[i]; int group_id = sh_g_idx_int_ptr[actual_k]; int rel_group_id = group_id - sh_first_group_id; - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; } }; @@ -1128,16 +1137,16 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; } } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = + int4 *sh_zp_stage = sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; } } else { int warp_id = threadIdx.x / 32; @@ -1151,13 +1160,13 @@ __global__ void Marlin( int k_blocks = cur_k / 16; int cur_group_id = k_blocks / group_blocks; - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; sh_zp_stage += cur_group_id * zp_sh_stride; for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; } } }; @@ -1186,9 +1195,9 @@ __global__ void Marlin( frag_zp[3] = frag_zp_1[1]; } - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll for (int j = 0; j < 4; j++) { FragB frag_b0; FragB frag_b1; @@ -1206,7 +1215,7 @@ __global__ void Marlin( } } else { - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; @@ -1252,7 +1261,7 @@ __global__ void Marlin( } } - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); @@ -1277,38 +1286,38 @@ __global__ void Marlin( // unnecessary read or write iterations, e.g., for two warps we write only // once by warp 1 and read only once by warp 0. - #pragma unroll +#pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll +#pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll +#pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll + float *c_rd = reinterpret_cast( + &sh[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll + float *c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -1321,7 +1330,7 @@ __global__ void Marlin( // finally have to globally reduce over the results. As the striped // partitioning minimizes the number of such reductions and our outputs are // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out // results in FP16 (but still reduce with FP32 compute). @@ -1339,39 +1348,39 @@ __global__ void Marlin( int row = (threadIdx.x % 32) / 4; if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll +// Interestingly, doing direct global accesses here really seems to mess up +// the compiler and lead to slowdowns, hence we also use async-copies even +// though these fetches are not actually asynchronous. +#pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll +#pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); + Dtype::num2float(reinterpret_cast(&c_red)[j]); } } if (!last) { int4 c; - #pragma unroll +#pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -1382,6 +1391,53 @@ __global__ void Marlin( } }; + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + int par_offset = c_size * n_tiles * par_id; + int slice_offset = c_size * slice_col; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = par_offset + slice_offset; + + if (!is_th_active) { + return; + } + + if (!first) { + float *frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k++) { + sh[threadIdx.x] = + C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float *sh_c_ptr = reinterpret_cast(&sh[threadIdx.x]); +#pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4 *frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k++) { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + // Write out the reduce final result in the correct layout. We only actually // reshuffle matrix fragments in this step, the reduction above is performed // in fragment layout. @@ -1405,7 +1461,7 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { + auto write = [&](int idx, float c0, float c1, FragS &s) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); @@ -1415,13 +1471,13 @@ __global__ void Marlin( res = __hmul2(res, s[0]); } - ((scalar_t2*)sh)[idx] = res; + ((scalar_t2 *)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll +#pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], @@ -1438,7 +1494,7 @@ __global__ void Marlin( } __syncthreads(); - #pragma unroll +#pragma unroll for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -1453,7 +1509,7 @@ __global__ void Marlin( // Start global fetch and register load pipelines. auto start_pipes = [&]() { - #pragma unroll +#pragma unroll for (int i = 0; i < stages - 1; i++) { if (has_act_order && i == 0) { int last_g_idx = slice_k_start + stages * tb_k * 2; @@ -1491,9 +1547,9 @@ __global__ void Marlin( // have even length meaning that the next iteration will always start at // index 0. - #pragma unroll +#pragma unroll for (int pipe = 0; pipe < stages;) { - #pragma unroll +#pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); fetch_scales_to_registers(k + 1, pipe); @@ -1560,8 +1616,8 @@ __global__ void Marlin( cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else { @@ -1569,8 +1625,8 @@ __global__ void Marlin( cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } } @@ -1581,35 +1637,39 @@ __global__ void Marlin( // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll +#pragma unroll for (int j = 0; j < 4; j++) { scale_float( - reinterpret_cast(&frag_c[i][j][0][0]), + reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); scale_float( - reinterpret_cast(&frag_c[i][j][0][2]), + reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + 0]); scale_float( - reinterpret_cast(&frag_c[i][j][1][0]), + reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); scale_float( - reinterpret_cast(&frag_c[i][j][1][2]), + reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } } } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; @@ -1618,12 +1678,13 @@ __global__ void Marlin( if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; } // Update slice k/n for scales loading @@ -1644,26 +1705,24 @@ __global__ void Marlin( } } - #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \ - prob_m, prob_n, prob_k, locks); \ - } +#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \ + num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \ + } typedef struct { int thread_k; @@ -1695,7 +1754,7 @@ thread_config_t large_batch_thread_configs[] = { }; -int get_scales_cache_size(thread_config_t const& th_config, int prob_m, +int get_scales_cache_size(thread_config_t const &th_config, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full) { bool cache_scales_chunk = has_act_order && !is_k_full; @@ -1708,15 +1767,15 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, if (group_size == -1) { tb_groups = 1; } else if (group_size == 0) { - tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size } else { tb_groups = div_ceil(tb_k, group_size); } if (cache_scales_chunk) { int load_groups = - tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; } else { @@ -1726,7 +1785,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, } } -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, +bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int scales_cache_size, int max_shared_mem) { int pack_factor = 32 / num_bits; @@ -1757,12 +1816,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, float pipe_size = (a_size + b_size) * pipe_stages; - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); } -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, +bool is_valid_config(thread_config_t const &th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int max_shared_mem) { @@ -1801,6 +1860,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, return true; } +int determine_reduce_max_m(int prob_m, int max_par) { + constexpr int tile_m_size = 16; + + if (prob_m <= tile_m_size) { + return tile_m_size; + + } else if (prob_m <= tile_m_size * 2) { + return tile_m_size * 2; + + } else if (prob_m <= tile_m_size * 3) { + return tile_m_size * 3; + + } else if (prob_m <= tile_m_size * 4) { + return tile_m_size * 4; + + } else { + int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par); + return tile_m_size * 4 * cur_par; + } +} + exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, @@ -1825,68 +1905,68 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, } } - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } return exec_config_t{0, {-1, -1, -1}}; } - #define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) +#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - #define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) +#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp, - void* g_idx, void* perm, void* a_tmp, int prob_m, - int prob_n, int prob_k, void* workspace, int num_bits, - bool has_act_order, bool is_k_full, bool has_zp, - int num_groups, int group_size, int dev, +void marlin_mm_f16i4(const void *A, const void *B, void *C, void *C_tmp, + void *s, void *zp, void *g_idx, void *perm, void *a_tmp, + int prob_m, int prob_n, int prob_k, void *workspace, + int num_bits, bool has_act_order, bool is_k_full, + bool has_zp, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k, int thread_n, int sms, - int max_par) { + int max_par, bool use_fp32_reduce) { TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -1967,16 +2047,17 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp, } } - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - const int4* zp_ptr = (const int4*)zp; - const int* g_idx_ptr = (const int*)g_idx; - const int* perm_ptr = (const int*)perm; - int4* a_tmp_ptr = (int4*)a_tmp; + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + int4 *C_ptr = (int4 *)C; + int4 *C_tmp_ptr = (int4 *)C_tmp; + const int4 *s_ptr = (const int4 *)s; + const int4 *zp_ptr = (const int4 *)zp; + const int *g_idx_ptr = (const int *)g_idx; + const int *perm_ptr = (const int *)perm; + int4 *a_tmp_ptr = (int4 *)a_tmp; - int* locks = (int*)workspace; + int *locks = (int *)workspace; if (has_act_order) { // Permute A columns @@ -2002,7 +2083,8 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp, // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) par = max_par; + if (par > max_par) + par = max_par; prob_m = (16 * exec_cfg.max_m_blocks) * par; i += exec_cfg.max_m_blocks * (par - 1); thread_m_blocks = exec_cfg.max_m_blocks; @@ -2042,14 +2124,15 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp, } } -} // namespace gptq_marlin +} // namespace gptq_marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, +torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &b_zeros, + torch::Tensor &g_idx, torch::Tensor &perm, + torch::Tensor &workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp) { + bool is_k_full, bool has_zp, + bool use_fp32_reduce) { // Verify num_bits TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); @@ -2099,6 +2182,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor c = torch::empty({size_m, size_n}, options); torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); + // Alloc C tmp buffer that is going to be used for the global reduce + int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par); + int reduce_n = size_n; + auto options_fp32 = + torch::TensorOptions().dtype(at::kFloat).device(a.device()); + if (!use_fp32_reduce) { + reduce_max_m = 0; + reduce_n = 0; + } + torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32); + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as // auto -1) int thread_k = -1; @@ -2171,20 +2265,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, if (a.scalar_type() == at::ScalarType::Half) { marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, + c_tmp.data_ptr(), b_scales.data_ptr(), + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par); + thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else if (a.scalar_type() == at::ScalarType::BFloat16) { marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), b_scales.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), size_m, size_n, size_k, + c.data_ptr(), c_tmp.data_ptr(), + b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par); + thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py index 697634ea..80674b14 100644 --- a/server/text_generation_server/layers/marlin/gptq.py +++ b/server/text_generation_server/layers/marlin/gptq.py @@ -223,6 +223,7 @@ class GPTQMarlinLinear(nn.Module): A_flat.shape[1], self.is_full_k, self.qzeros.numel() > 0, + True, ) C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))