Update quantization kernels

Daniël de Kok 2025-07-07 06:12:18 +00:00
parent 778b61c0da
commit a76ae953fe
7 changed files with 69 additions and 76 deletions

View File

@@ -586,15 +586,16 @@
         "nixpkgs": "nixpkgs_6"
       },
       "locked": {
-        "lastModified": 1747919133,
-        "narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=",
+        "lastModified": 1751868511,
+        "narHash": "sha256-C/TjJact7KBy88f0TK9T5PtpZQd79ak4kwGa/Ns1exM=",
         "owner": "huggingface",
         "repo": "hf-nix",
-        "rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c",
+        "rev": "824cd64af49e1d8f939130dd1641017a5a4238a2",
         "type": "github"
       },
       "original": {
         "owner": "huggingface",
+        "ref": "quantization-0.1.0",
         "repo": "hf-nix",
         "type": "github"
       }

View File

@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "hf-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    hf-nix.url = "github:huggingface/hf-nix";
+    hf-nix.url = "github:huggingface/hf-nix/quantization-0.1.0";
     nixpkgs.follows = "hf-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
@@ -33,7 +33,7 @@
         };
         pkgs = import nixpkgs {
           inherit system;
-          inherit (hf-nix.lib) config;
+          config = hf-nix.lib.config system;
           overlays = [
             rust-overlay.overlays.default
             hf-nix.overlays.default

View File

@@ -223,82 +223,58 @@
   },
   {
     "repo_id": "kernels-community/quantization",
-    "sha": "6470f9b005797e00279eb9103463dfe0f8b7da00",
+    "sha": "229f047e826202eb49dc0321bb38aed5d3ab96e3",
     "variants": {
-      "torch25-cxx11-cu118-x86_64-linux": {
-        "hash": "sha256-f52c9b1a7cd98fb389c6d2a0b22a293cb36eb96af3a624f5aec761735861c96d",
-        "hash_type": "git_lfs_concat"
-      },
-      "torch25-cxx11-cu121-x86_64-linux": {
-        "hash": "sha256-e5f0da343363a562ce52f147a9534cd54a3efa90e70671f606cc2516f02a3876",
-        "hash_type": "git_lfs_concat"
-      },
-      "torch25-cxx11-cu124-x86_64-linux": {
-        "hash": "sha256-caad9300c155faf79c26426f10951ba75f931a05e741a5b39a24b064daabc040",
-        "hash_type": "git_lfs_concat"
-      },
-      "torch25-cxx98-cu118-x86_64-linux": {
-        "hash": "sha256-4fc87893de14a29ba4b55f5026ea05ec5901c0b52abd5ebae681ea0b791e858c",
-        "hash_type": "git_lfs_concat"
-      },
-      "torch25-cxx98-cu121-x86_64-linux": {
-        "hash": "sha256-72c975ea63fc524a38fcee5b2dbdb566eff0a0ea546ee5756441d04908e4e896",
-        "hash_type": "git_lfs_concat"
-      },
-      "torch25-cxx98-cu124-x86_64-linux": {
-        "hash": "sha256-28c5510e3b07eae2b3846b880f6111da65df024e1f24f81077d187a97c015364",
-        "hash_type": "git_lfs_concat"
-      },
       "torch26-cxx11-cu118-x86_64-linux": {
-        "hash": "sha256-8444cf77686578a6b0f7e2fd29bf2783ba120ebf7df41573f61d2521fd0acc10",
+        "hash": "sha256-354e86a4a1fc38bfaddb3bf98c083ccd8a00de721d6769e0f3c594b719c9dbd2",
         "hash_type": "git_lfs_concat"
       },
       "torch26-cxx11-cu124-x86_64-linux": {
-        "hash": "sha256-6ea8e00625b5fe799fbe407e7de0fc08228cac26f9bbed2d70a6500026fe3bab",
+        "hash": "sha256-99523c409552d6a0a514987bd31b427c273695abaa1085be85f9f243f6ff8184",
         "hash_type": "git_lfs_concat"
       },
       "torch26-cxx11-cu126-aarch64-linux": {
-        "hash": "sha256-0b8b8afbdaf9aa533895cb9e884e3ad3e9a34d483f05a1bbde1b8902f9dbeb0f",
+        "hash": "sha256-c7ed22cb6bb3cf23b3b36e157a3f902b2d22f2236a30e2e72110033aff4485c1",
         "hash_type": "git_lfs_concat"
       },
       "torch26-cxx11-cu126-x86_64-linux": {
-        "hash": "sha256-e115e855d7ca4b97787f04c88e128432256c6b43d4823fb8889ab9985dc4cf36",
+        "hash": "sha256-91498f3a73741f2e9b63467f0992fa28daabb3c0d9d06aec2fb650285fa7df92",
         "hash_type": "git_lfs_concat"
       },
       "torch26-cxx98-cu118-x86_64-linux": {
-        "hash": "sha256-509f08c48a05584cc85c058607277fcbe3193e6cc61846dd2416d39e27c1d68e",
+        "hash": "sha256-fcf32cbeb606021b80f3d1c86ca977a13a680fb4a7c15738487b35bc8f9edc04",
         "hash_type": "git_lfs_concat"
       },
       "torch26-cxx98-cu124-x86_64-linux": {
-        "hash": "sha256-a10236bffd435296c736ae2762ab0836da2421297e46b377368a17b39d70c27b",
+        "hash": "sha256-eeff3d5134795a25bb484b95a11f72658ef096766d13a126530cc379cb74850b",
         "hash_type": "git_lfs_concat"
       },
       "torch26-cxx98-cu126-aarch64-linux": {
-        "hash": "sha256-ca2cb56f3eea4c399a61e21ba9b577d718b250aa60a13f42f01019ddd5cd8b0c",
+        "hash": "sha256-8aaaae2f066c2b041828703d09882f80e9c058527385b0cfe256349972d12929",
         "hash_type": "git_lfs_concat"
       },
       "torch26-cxx98-cu126-x86_64-linux": {
-        "hash": "sha256-8fcd62d8243a30b63a03751cc0c15d24f6e00e43eae79f7281627f24e078bf9a",
+        "hash": "sha256-6556ddcd229b4532572294a1313f394a0f9f15be8d1cab1007dbc0ba712a1a94",
         "hash_type": "git_lfs_concat"
       },
       "torch27-cxx11-cu118-x86_64-linux": {
-        "hash": "sha256-60f5807ee3da937c57c1b6080c30632305aa4875ed5a52bf4e81968770b61b13",
+        "hash": "sha256-c5fe51f7830adc47a642151256b023fde606611e641fb12acccf9e5cc2d319e3",
         "hash_type": "git_lfs_concat"
       },
       "torch27-cxx11-cu126-aarch64-linux": {
-        "hash": "sha256-64298b1713dc1d950915dc6569a06e2f541de3ed80aa5b32084246c1fdc7a958",
+        "hash": "sha256-aec9c8d1e3653c700da624cedc7619af4eea77a1ba5b0f1093f7ea22d811f335",
         "hash_type": "git_lfs_concat"
       },
       "torch27-cxx11-cu126-x86_64-linux": {
-        "hash": "sha256-d9e219890dc28e8582ef21d6f81f2ebc361de218a86b742be63bc4714f102e5e",
+        "hash": "sha256-82623a36b6921357373ec767114438a4818a86087f56944d25fe21946b217420",
         "hash_type": "git_lfs_concat"
       },
       "torch27-cxx11-cu128-aarch64-linux": {
-        "hash": "sha256-d72549f51aefcf020bc74262bbbccb78094638c5ab9adc8667873d247c1cce86",
+        "hash": "sha256-d110173a26cb02d80c5462434491f30e41f66e21c3a9723f9e4edc4cf3a9bd9f",
         "hash_type": "git_lfs_concat"
       },
       "torch27-cxx11-cu128-x86_64-linux": {
-        "hash": "sha256-d31ac5f87d7c7f62c63c72946479193aed467c9417c0acead5137e0e1fa968f8",
+        "hash": "sha256-936c75f188ffcd8debbaffed37c73edbffaaa05462bd4d2fc78f767fc4678755",
         "hash_type": "git_lfs_concat"
       }
     }
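Each variant key above encodes the Torch minor version, C++ ABI, CUDA version, CPU architecture, and OS; the new kernel revision also drops the torch25 builds. A minimal sketch of how such a key could be derived and looked up at runtime — the helper is hypothetical (the real resolution lives inside the kernels library), and it assumes a CUDA build, the cxx11 ABI, and that the lock file is a JSON array named kernels.lock as the hunk suggests:

import json
import platform

import torch

def variant_key() -> str:
    # Hypothetical helper: build a key like "torch27-cxx11-cu128-x86_64-linux".
    major, minor = torch.__version__.split("+")[0].split(".")[:2]
    cuda = "cu" + torch.version.cuda.replace(".", "")  # assumes a CUDA build
    return f"torch{major}{minor}-cxx11-{cuda}-{platform.machine()}-linux"

with open("kernels.lock") as f:
    lock = json.load(f)

entry = next(e for e in lock if e["repo_id"] == "kernels-community/quantization")
print(entry["sha"], entry["variants"][variant_key()]["hash"])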

View File

@@ -59,7 +59,7 @@ build-backend = "setuptools.build_meta"
 "kernels-community/paged-attention" = ">=0.0.2"
 "kernels-community/moe" = ">=0.1.1"
 "kernels-community/punica-sgmv" = ">=0.0.1"
-"kernels-community/quantization" = ">=0.0.3"
+"kernels-community/quantization" = ">=0.1.1"
 "kernels-community/quantization-eetq" = ">=0.0.1"
 "kernels-community/rotary" = ">=0.0.1"
 

View File

@@ -76,15 +76,21 @@ class GPTQMarlinFP8Linear(nn.Module):
         assert quantization is not None
 
         A_flat = A.view(-1, A.shape[-1])
-        C = quantization.fp8_marlin_gemm(
-            A_flat,
-            self.qweight,
-            self.scales,
-            self.workspace,
-            8,
-            A_flat.shape[0],
-            self.scales.shape[1],
-            A_flat.shape[1],
+        C = quantization.gptq_marlin_gemm(
+            a=A_flat,
+            c=None,
+            b_q_weight=self.qweight,
+            b_scales=self.scales,
+            global_scale=None,
+            b_zeros=None,
+            g_idx=None,
+            perm=None,
+            workspace=self.workspace,
+            b_q_type=quantization.scalar_type.scalar_types.float8_e4m3fn,
+            size_m=A_flat.shape[0],
+            size_n=self.scales.shape[1],
+            size_k=A_flat.shape[1],
+            use_fp32_reduce=True,
         )
         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
 
@@ -143,5 +149,6 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
     )
 
     scales = permute_scales(scales)
+    scales = quantization.marlin_utils_fp8.fp8_fused_exponent_bias_into_scales(scales)
 
     return repacked, scales
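The dedicated fp8_marlin_gemm entry point is gone in quantization 0.1.x: FP8 now runs through the same gptq_marlin_gemm kernel, with b_q_type selecting float8_e4m3fn and the zero-point/act-order arguments left as None; repacking additionally fuses the FP8 exponent bias into the scales. A sketch of the mapping from the old positional call to the new keyword API (argument names follow the hunk above; the wrapper function itself is hypothetical):

def fp8_gemm_via_gptq_marlin(quantization, A_flat, qweight, scales, workspace):
    # Old: fp8_marlin_gemm(A_flat, qweight, scales, workspace, 8, m, n, k).
    # New: one kernel covers GPTQ, AWQ, and FP8; the weight type is an argument.
    return quantization.gptq_marlin_gemm(
        a=A_flat,
        c=None,                  # let the kernel allocate the output
        b_q_weight=qweight,      # repacked by repack_fp8_for_marlin
        b_scales=scales,         # exponent bias must already be fused in
        global_scale=None,
        b_zeros=None,            # FP8 weights carry no zero points
        g_idx=None,              # ... and no activation reordering
        perm=None,
        workspace=workspace,
        b_q_type=quantization.scalar_type.scalar_types.float8_e4m3fn,
        size_m=A_flat.shape[0],
        size_n=scales.shape[1],
        size_k=A_flat.shape[1],
        use_fp32_reduce=True,    # accumulate reductions in fp32
    )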

View File

@@ -256,7 +256,7 @@ class GPTQMarlinWeight(Weight):
     """
 
     qweight: torch.Tensor
-    qzeros: torch.Tensor
+    qzeros: Optional[torch.Tensor]
     scales: torch.Tensor
     g_idx: torch.Tensor
     perm: torch.Tensor
@@ -268,6 +268,7 @@ class GPTQMarlinWeight(Weight):
         assert self.scales.dtype in (torch.float16, torch.bfloat16)
         assert self.g_idx.dtype == torch.int32
         assert self.perm.dtype == torch.int32
+        assert self.qzeros is None or self.qzeros.numel() > 0
 
     def get_linear(self, bias: torch.Tensor):
         return GPTQMarlinLinear(
@@ -350,9 +351,6 @@ def repack_gptq_for_marlin(
         qweight, perm, in_features, out_features, bits
     )
 
-    if qzeros is None:
-        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)
-
     scales = permute_scales(scales)
 
     is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
@@ -392,7 +390,7 @@ class GPTQMarlinLinear(nn.Module):
         if weight.bits not in (4, 8):
             raise ValueError("GPTQMarlinLinear only supports 4 and 8-bit quantization")
 
-        if weight.qzeros.numel() > 0:
+        if weight.qzeros is not None:
             if weight.bits == 4:
                 self.quant_type = quantization.scalar_types.uint4
             else:
@@ -424,20 +422,21 @@ class GPTQMarlinLinear(nn.Module):
 
         A_flat = A.view(-1, A.shape[-1])
         C = quantization.gptq_marlin_gemm(
-            A_flat,
-            self.qweight,
-            self.scales,
-            self.qzeros,
-            self.g_idx,
-            self.perm,
-            self.workspace,
-            self.quant_type,
-            A_flat.shape[0],
-            self.scales.shape[1],
-            A_flat.shape[1],
-            self.is_full_k,
-            self.qzeros.numel() > 0,
-            True,
+            a=A_flat,
+            c=None,
+            b_q_weight=self.qweight,
+            b_scales=self.scales,
+            global_scale=None,
+            b_zeros=self.qzeros,
+            g_idx=self.g_idx,
+            perm=self.perm,
+            workspace=self.workspace,
+            b_q_type=self.quant_type,
+            size_m=A_flat.shape[0],
+            size_n=self.scales.shape[1],
+            size_k=A_flat.shape[1],
+            is_k_full=self.is_full_k,
+            use_fp32_reduce=True,
        )
        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
 
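The recurring change in this file is the sentinel for "no zero points": previously an empty int32 tensor tested via numel(), now None, which the kernel's b_zeros keyword accepts directly (so the old has_zp boolean argument disappears). A small sketch contrasting the two conventions, with hypothetical helper names:

from typing import Optional

import torch

def has_zero_points_old(qzeros: torch.Tensor) -> bool:
    # Old convention: an empty tensor meant "symmetric quantization, no zeros".
    return qzeros.numel() > 0

def has_zero_points_new(qzeros: Optional[torch.Tensor]) -> bool:
    # New convention: absence is spelled None and passed straight to the kernel.
    return qzeros is not None

assert not has_zero_points_old(torch.empty(0, dtype=torch.int32))
assert not has_zero_points_new(None)
assert has_zero_points_new(torch.zeros(2, 8, dtype=torch.int32))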

View File

@@ -202,9 +202,13 @@ def _pack_weight(
             device=weight.qweight.device,
         )
         qzeros = torch.empty(
-            (n_experts,) + weight.qzeros.shape,
-            dtype=weight.qzeros.dtype,
-            device=weight.qzeros.device,
+            (n_experts,) + ((0,) if weight.qzeros is None else weight.qzeros.shape),
+            dtype=(
+                weight.qweight.dtype if weight.qzeros is None else weight.qzeros.dtype
+            ),
+            device=(
+                weight.qweight.device if weight.qzeros is None else weight.qzeros.device
+            ),
         )
         scales = torch.empty(
             (n_experts,) + weight.scales.shape,
@@ -232,7 +236,13 @@ def _pack_weight(
         )
 
     moe_weight.qweight[expert] = weight.qweight
-    moe_weight.qzeros[expert] = weight.qzeros
+    moe_weight.qzeros[expert] = (
+        torch.zeros(
+            (0,), device=moe_weight.qzeros.device, dtype=moe_weight.qzeros.dtype
+        )
+        if weight.qzeros is None
+        else weight.qzeros
+    )
     moe_weight.scales[expert] = weight.scales
     moe_weight.g_idx[expert] = weight.g_idx
     moe_weight.perm[expert] = weight.perm
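The MoE packing code stacks per-expert tensors into one pre-allocated buffer, so a None qzeros still needs a concrete zero-sized tensor for the indexed copy; the hunk above inlines that fallback. A hedged sketch of the same pattern as a standalone helper (the function name is invented):

from typing import Optional

import torch

def expert_qzeros_slice(
    qzeros: Optional[torch.Tensor], packed: torch.Tensor
) -> torch.Tensor:
    # When an expert has no zero points, substitute an empty tensor matching
    # the stacked buffer's dtype/device so the indexed assignment is a no-op.
    if qzeros is None:
        return torch.zeros((0,), device=packed.device, dtype=packed.dtype)
    return qzeros

Keeping the stacked qzeros buffer zero-sized (rather than Optional per expert) preserves a uniform layout across experts, which is what the fused MoE kernels index into.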