hotfix: increase precision of GPTQ/AWQ-Marlin

Sync with an upstream change that improves the precision of the
'global_reduce' algorithm from FP16 to FP32. This resolves some
reported generation quality issues.

Upstream issue/PR:

https://github.com/vllm-project/vllm/pull/6795

Daniël de Kok 2024-07-29 08:40:17 +00:00
parent 4b49c50f4c
commit 4f69d04c3a
4 changed files with 491 additions and 386 deletions
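
As a rough illustration of why the accumulator dtype matters (a toy example,
not the Marlin kernel's actual 'global_reduce'), summing many small FP16
values with an FP16 running sum loses the small addends once the sum grows,
while an FP32 running sum over the same FP16 inputs stays close to the exact
result:

    import torch

    x = torch.full((4096,), 0.01, dtype=torch.float16)

    # Sequential reduction with an FP16 accumulator: every addition is rounded
    # back to FP16, so once the running sum reaches 32 the 0.01 addends fall
    # below half an FP16 ulp and are rounded away entirely.
    acc16 = torch.zeros((), dtype=torch.float16)
    for v in x:
        acc16 = acc16 + v

    # Same FP16 inputs, but the running sum is kept in FP32.
    acc32 = torch.zeros((), dtype=torch.float32)
    for v in x:
        acc32 = acc32 + v.float()

    print(acc16.item())  # stalls at 32.0
    print(acc32.item())  # ~40.97, close to the exact sum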

@@ -1,5 +1,11 @@
import torch
def awq_marlin_repack(
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
"""Repack AWQ parameters for GPTQ-Marlin."""
...
def gptq_marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
@@ -12,6 +18,8 @@ def gptq_marlin_gemm(
size_n: int,
size_k: int,
is_k_full: bool,
has_zp: bool,
use_fp32_reduce: bool,
) -> torch.Tensor:
"""
Matrix multiplication using Marlin kernels. This is an extension of


@@ -14,7 +14,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &g_idx, torch::Tensor &perm,
torch::Tensor &workspace, int64_t num_bits,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp);
bool is_k_full, bool has_zp,
bool use_fp32_reduce);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_meta,

File diff suppressed because it is too large

@@ -223,6 +223,7 @@ class GPTQMarlinLinear(nn.Module):
A_flat.shape[1],
self.is_full_k,
self.qzeros.numel() > 0,
True,
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
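
The bare True at the end of this call is the new use_fp32_reduce argument.
Read against the declaration in the header above (an interpretation of the
positional arguments, not code from the file), the tail of the call maps as:

    #   size_k          -> A_flat.shape[1]
    #   is_k_full       -> self.is_full_k
    #   has_zp          -> self.qzeros.numel() > 0
    #   use_fp32_reduce -> True   (hard-coded: always request the FP32 global reduce)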