Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00
hotfix: increase precision of GPTQ/AWQ-Marlin
Sync with an upstream change that raises the precision of the 'global_reduce' algorithm in the Marlin kernels from FP16 to FP32. This resolves some reported generation quality issues. Upstream PR: https://github.com/vllm-project/vllm/pull/6795
parent 4b49c50f4c
commit 4f69d04c3a
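The underlying numerics are a textbook accumulation effect: with an FP16 accumulator, any addend smaller than half a ulp of the running sum is rounded away entirely, so long dot-product reductions silently drop low-order contributions. A minimal sketch of the effect in PyTorch (illustrative only, not the kernel's reduction code; the values are chosen to make the stall obvious):

import torch

# Accumulate 16384 small values twice: once with an FP16 accumulator,
# once with an FP32 accumulator. The exact sum is about 16.38.
vals = torch.full((16384,), 1e-3, dtype=torch.float16)

acc16 = torch.tensor(0.0, dtype=torch.float16)
acc32 = torch.tensor(0.0, dtype=torch.float32)
for v in vals:
    acc16 = acc16 + v          # rounded back to FP16 after every add
    acc32 = acc32 + v.float()  # accumulated in FP32

# Once acc16 reaches 4.0, the FP16 ulp is 2**-8 = 0.0039, so adding ~0.001
# falls below half a ulp and the running sum stops moving.
print(acc16.item())  # 4.0
print(acc32.item())  # ~16.38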
@@ -1,5 +1,11 @@
 import torch
 
+def awq_marlin_repack(
+    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    """Repack AWQ parameters for GPTQ-Marlin."""
+    ...
+
 def gptq_marlin_gemm(
     a: torch.Tensor,
     b_q_weight: torch.Tensor,
@@ -12,6 +18,8 @@ def gptq_marlin_gemm(
     size_n: int,
     size_k: int,
     is_k_full: bool,
+    has_zp: bool,
+    use_fp32_reduce: bool,
 ) -> torch.Tensor:
     """
     Matrix multiplication using Marlin kernels. This is an extension of
@@ -14,7 +14,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                torch::Tensor &g_idx, torch::Tensor &perm,
                                torch::Tensor &workspace, int64_t num_bits,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp);
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce);
 
 torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                   torch::Tensor &b_meta,
File diff suppressed because it is too large
@@ -223,6 +223,7 @@ class GPTQMarlinLinear(nn.Module):
             A_flat.shape[1],
             self.is_full_k,
             self.qzeros.numel() > 0,
+            True,
         )
         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
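The call site pins the new trailing argument to True, which, per the stub signature above, is use_fp32_reduce: partial products are still computed from FP16 inputs, but the global reduce accumulates in FP32 and only casts back to FP16 at the end. A rough tensor-level emulation of the two reduction modes (a sketch for intuition; the real kernel fuses this into the GEMM, and this is not its code path):

import torch

torch.manual_seed(0)
a = torch.randn(8, 4096, dtype=torch.float16)
b = torch.randn(4096, 8, dtype=torch.float16)

# Elementwise FP16 partial products, shape (8, 4096, 8).
prods = a.unsqueeze(-1) * b.unsqueeze(0)

# FP16 reduce: sequential adds over k, rounded to FP16 at every step.
c_fp16 = torch.zeros(8, 8, dtype=torch.float16)
for k in range(prods.shape[1]):
    c_fp16 = c_fp16 + prods[:, k, :]

# FP32 reduce: accumulate in FP32, cast back to FP16 once at the end.
c_fp32 = prods.sum(dim=1, dtype=torch.float32).half()

# FP64 reference to measure each variant's reduction error.
ref = (a.double() @ b.double()).half()
print((c_fp16 - ref).abs().max().item())  # accumulated FP16 rounding error
print((c_fp32 - ref).abs().max().item())  # near a single FP16 rounding step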