Add support for compressed-tensors w8a8 int checkpoints

This change adds a loader for w8a8 int checkpoints. One large benefit of
int8 support is that the corresponding cutlass matmul kernels also work on
compute capability 7.5.

Evaluation on neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8:

|     Tasks     |Version|     Filter     |n-shot|        Metric         |   |Value |   |Stderr|
|---------------|------:|----------------|-----:|-----------------------|---|-----:|---|------|
|gsm8k_cot_llama|      3|flexible-extract|     8|exact_match            |↑  |0.8431|±  |0.0100|
|               |       |strict-match    |     8|exact_match            |↑  |0.8393|±  |0.0101|
|ifeval         |      4|none            |     0|inst_level_loose_acc   |↑  |0.8597|±  |   N/A|
|               |       |none            |     0|inst_level_strict_acc  |↑  |0.8201|±  |   N/A|
|               |       |none            |     0|prompt_level_loose_acc |↑  |0.7967|±  |0.0173|
|               |       |none            |     0|prompt_level_strict_acc|↑  |0.7468|±  |0.0187|

These results are in the same ballpark as vLLM's.

As usual, lots of thanks to Neural Magic/vLLM for the kernels.
Daniël de Kok 2024-11-14 11:00:29 +00:00
parent 52e48739a5
commit b2dc10aea5
14 changed files with 1803 additions and 110 deletions
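
For reference, "w8a8 int" means the weights are stored as int8 with per-channel (or per-tensor) scales, activations are quantized to int8 on the fly, and the matmul accumulates in int32 before dequantizing. Below is a minimal pure-PyTorch sketch of that arithmetic; it is a reference model only (the loader added in this commit runs the matmul through the cutlass kernels in marlin_kernels), and all names in it are illustrative.

```python
import torch


def quantize_per_channel(w: torch.Tensor):
    # Symmetric int8 weight quantization, one scale per output channel.
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale


def w8a8_linear(x: torch.Tensor, qw: torch.Tensor, w_scale: torch.Tensor):
    # Dynamic symmetric int8 quantization of the activations.
    x_scale = x.abs().amax() / 127.0
    qx = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
    # int8 x int8 matmul with int32 accumulation, then dequantization.
    acc = qx.to(torch.int32) @ qw.to(torch.int32).t()
    return acc.to(torch.float32) * x_scale * w_scale.t()


w = torch.randn(256, 128)
x = torch.randn(4, 128)
qw, w_scale = quantize_per_channel(w)
err = (w8a8_linear(x, qw, w_scale) - x @ w.t()).abs().max()
print(f"max abs error vs. fp32 matmul: {err.item():.3f}")
```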

flake.lock

@@ -2,16 +2,10 @@
   "nodes": {
     "cachix": {
       "inputs": {
-        "devenv": [
-          "crate2nix"
-        ],
-        "flake-compat": [
-          "crate2nix"
-        ],
+        "devenv": ["crate2nix"],
+        "flake-compat": ["crate2nix"],
         "nixpkgs": "nixpkgs",
-        "pre-commit-hooks": [
-          "crate2nix"
-        ]
+        "pre-commit-hooks": ["crate2nix"]
       },
       "locked": {
         "lastModified": 1709700175,
@@ -30,19 +24,10 @@
     },
     "cachix_2": {
       "inputs": {
-        "devenv": [
-          "crate2nix",
-          "crate2nix_stable"
-        ],
-        "flake-compat": [
-          "crate2nix",
-          "crate2nix_stable"
-        ],
+        "devenv": ["crate2nix", "crate2nix_stable"],
+        "flake-compat": ["crate2nix", "crate2nix_stable"],
         "nixpkgs": "nixpkgs_2",
-        "pre-commit-hooks": [
-          "crate2nix",
-          "crate2nix_stable"
-        ]
+        "pre-commit-hooks": ["crate2nix", "crate2nix_stable"]
       },
       "locked": {
         "lastModified": 1716549461,
@@ -61,16 +46,8 @@
     },
     "cachix_3": {
       "inputs": {
-        "devenv": [
-          "crate2nix",
-          "crate2nix_stable",
-          "crate2nix_stable"
-        ],
-        "flake-compat": [
-          "crate2nix",
-          "crate2nix_stable",
-          "crate2nix_stable"
-        ],
+        "devenv": ["crate2nix", "crate2nix_stable", "crate2nix_stable"],
+        "flake-compat": ["crate2nix", "crate2nix_stable", "crate2nix_stable"],
         "nixpkgs": "nixpkgs_3",
         "pre-commit-hooks": [
           "crate2nix",
@@ -101,10 +78,7 @@
         "flake-compat": "flake-compat_3",
         "flake-parts": "flake-parts_3",
         "nix-test-runner": "nix-test-runner_3",
-        "nixpkgs": [
-          "tgi-nix",
-          "nixpkgs"
-        ],
+        "nixpkgs": ["tgi-nix", "nixpkgs"],
         "pre-commit-hooks": "pre-commit-hooks_3"
       },
       "locked": {
@@ -219,11 +193,7 @@
     "devshell_2": {
       "inputs": {
         "flake-utils": "flake-utils_3",
-        "nixpkgs": [
-          "crate2nix",
-          "crate2nix_stable",
-          "nixpkgs"
-        ]
+        "nixpkgs": ["crate2nix", "crate2nix_stable", "nixpkgs"]
       },
       "locked": {
         "lastModified": 1717408969,
@@ -242,10 +212,7 @@
     "devshell_3": {
       "inputs": {
         "flake-utils": "flake-utils_4",
-        "nixpkgs": [
-          "crate2nix",
-          "nixpkgs"
-        ]
+        "nixpkgs": ["crate2nix", "nixpkgs"]
       },
       "locked": {
         "lastModified": 1711099426,
@@ -343,11 +310,7 @@
     },
     "flake-parts_2": {
       "inputs": {
-        "nixpkgs-lib": [
-          "crate2nix",
-          "crate2nix_stable",
-          "nixpkgs"
-        ]
+        "nixpkgs-lib": ["crate2nix", "crate2nix_stable", "nixpkgs"]
       },
       "locked": {
         "lastModified": 1719745305,
@@ -365,10 +328,7 @@
     },
     "flake-parts_3": {
       "inputs": {
-        "nixpkgs-lib": [
-          "crate2nix",
-          "nixpkgs"
-        ]
+        "nixpkgs-lib": ["crate2nix", "nixpkgs"]
       },
       "locked": {
         "lastModified": 1712014858,
@@ -559,11 +519,7 @@
     },
     "gitignore_3": {
       "inputs": {
-        "nixpkgs": [
-          "crate2nix",
-          "pre-commit-hooks",
-          "nixpkgs"
-        ]
+        "nixpkgs": ["crate2nix", "pre-commit-hooks", "nixpkgs"]
       },
       "locked": {
         "lastModified": 1709087332,
@@ -770,22 +726,10 @@
     },
     "pre-commit-hooks_2": {
      "inputs": {
-        "flake-compat": [
-          "crate2nix",
-          "crate2nix_stable",
-          "flake-compat"
-        ],
+        "flake-compat": ["crate2nix", "crate2nix_stable", "flake-compat"],
         "gitignore": "gitignore_2",
-        "nixpkgs": [
-          "crate2nix",
-          "crate2nix_stable",
-          "nixpkgs"
-        ],
-        "nixpkgs-stable": [
-          "crate2nix",
-          "crate2nix_stable",
-          "nixpkgs"
-        ]
+        "nixpkgs": ["crate2nix", "crate2nix_stable", "nixpkgs"],
+        "nixpkgs-stable": ["crate2nix", "crate2nix_stable", "nixpkgs"]
       },
       "locked": {
         "lastModified": 1719259945,
@@ -803,20 +747,11 @@
     },
     "pre-commit-hooks_3": {
       "inputs": {
-        "flake-compat": [
-          "crate2nix",
-          "flake-compat"
-        ],
+        "flake-compat": ["crate2nix", "flake-compat"],
         "flake-utils": "flake-utils_5",
         "gitignore": "gitignore_3",
-        "nixpkgs": [
-          "crate2nix",
-          "nixpkgs"
-        ],
-        "nixpkgs-stable": [
-          "crate2nix",
-          "nixpkgs"
-        ]
+        "nixpkgs": ["crate2nix", "nixpkgs"],
+        "nixpkgs-stable": ["crate2nix", "nixpkgs"]
       },
       "locked": {
         "lastModified": 1712055707,
@@ -837,20 +772,14 @@
         "crate2nix": "crate2nix",
         "flake-utils": "flake-utils_6",
         "nix-filter": "nix-filter",
-        "nixpkgs": [
-          "tgi-nix",
-          "nixpkgs"
-        ],
+        "nixpkgs": ["tgi-nix", "nixpkgs"],
         "rust-overlay": "rust-overlay",
         "tgi-nix": "tgi-nix"
       }
     },
     "rust-overlay": {
       "inputs": {
-        "nixpkgs": [
-          "tgi-nix",
-          "nixpkgs"
-        ]
+        "nixpkgs": ["tgi-nix", "nixpkgs"]
       },
       "locked": {
         "lastModified": 1729045942,

Snapshot: test_compressed_tensors_w8a8_int (generated)

@@ -0,0 +1,104 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -6.3867188,
"text": "What"
},
{
"id": 374,
"logprob": -1.1318359,
"text": " is"
},
{
"id": 5655,
"logprob": -9.6875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.3007812,
"text": " learning"
},
{
"id": 30,
"logprob": -2.4902344,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 323,
"logprob": -1.1171875,
"special": false,
"text": " and"
},
{
"id": 1268,
"logprob": -0.9477539,
"special": false,
"text": " how"
},
{
"id": 1587,
"logprob": -0.51464844,
"special": false,
"text": " does"
},
{
"id": 433,
"logprob": -0.043182373,
"special": false,
"text": " it"
},
{
"id": 1782,
"logprob": -1.0810547,
"special": false,
"text": " differ"
},
{
"id": 505,
"logprob": -0.005054474,
"special": false,
"text": " from"
},
{
"id": 8776,
"logprob": -0.47485352,
"special": false,
"text": " traditional"
},
{
"id": 5780,
"logprob": -0.15112305,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0011291504,
"special": false,
"text": " learning"
},
{
"id": 5380,
"logprob": -0.31323242,
"special": false,
"text": "?\n"
}
],
"top_tokens": null
},
"generated_text": " and how does it differ from traditional machine learning?\n"
}

Snapshot: test_compressed_tensors_w8a8_int_all_params (generated)

@@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -6.3867188,
"text": "What"
},
{
"id": 374,
"logprob": -1.1318359,
"text": " is"
},
{
"id": 5655,
"logprob": -9.6875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.3007812,
"text": " learning"
}
],
"seed": 0,
"tokens": [
{
"id": 5380,
"logprob": 0.0,
"special": false,
"text": "?\n"
},
{
"id": 34564,
"logprob": 0.0,
"special": false,
"text": "Deep"
},
{
"id": 6975,
"logprob": 0.0,
"special": false,
"text": " learning"
},
{
"id": 11,
"logprob": 0.0,
"special": false,
"text": ","
},
{
"id": 1101,
"logprob": -1.0947266,
"special": false,
"text": " also"
},
{
"id": 3967,
"logprob": 0.0,
"special": false,
"text": " known"
},
{
"id": 439,
"logprob": 0.0,
"special": false,
"text": " as"
},
{
"id": 30828,
"logprob": 0.0,
"special": false,
"text": " neural"
},
{
"id": 4009,
"logprob": -0.15563965,
"special": false,
"text": " network"
},
{
"id": 477,
"logprob": -1.4003906,
"special": false,
"text": " or"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning?\nDeep learning, also known as neural network or"
}

Snapshot: test_compressed_tensors_w8a8_int_load (generated)

@@ -0,0 +1,418 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -6.3867188,
"text": "What"
},
{
"id": 374,
"logprob": -1.1318359,
"text": " is"
},
{
"id": 5655,
"logprob": -9.6875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.3007812,
"text": " learning"
},
{
"id": 30,
"logprob": -2.4902344,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 323,
"logprob": -1.1171875,
"special": false,
"text": " and"
},
{
"id": 1268,
"logprob": -0.9477539,
"special": false,
"text": " how"
},
{
"id": 1587,
"logprob": -0.51464844,
"special": false,
"text": " does"
},
{
"id": 433,
"logprob": -0.043182373,
"special": false,
"text": " it"
},
{
"id": 1782,
"logprob": -1.0810547,
"special": false,
"text": " differ"
},
{
"id": 505,
"logprob": -0.005054474,
"special": false,
"text": " from"
},
{
"id": 8776,
"logprob": -0.47485352,
"special": false,
"text": " traditional"
},
{
"id": 5780,
"logprob": -0.15112305,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0011291504,
"special": false,
"text": " learning"
},
{
"id": 5380,
"logprob": -0.3173828,
"special": false,
"text": "?\n"
}
],
"top_tokens": null
},
"generated_text": " and how does it differ from traditional machine learning?\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -6.3867188,
"text": "What"
},
{
"id": 374,
"logprob": -1.1318359,
"text": " is"
},
{
"id": 5655,
"logprob": -9.6875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.3007812,
"text": " learning"
},
{
"id": 30,
"logprob": -2.4902344,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 323,
"logprob": -1.1171875,
"special": false,
"text": " and"
},
{
"id": 1268,
"logprob": -0.9477539,
"special": false,
"text": " how"
},
{
"id": 1587,
"logprob": -0.51464844,
"special": false,
"text": " does"
},
{
"id": 433,
"logprob": -0.043182373,
"special": false,
"text": " it"
},
{
"id": 1782,
"logprob": -1.0810547,
"special": false,
"text": " differ"
},
{
"id": 505,
"logprob": -0.005054474,
"special": false,
"text": " from"
},
{
"id": 8776,
"logprob": -0.47485352,
"special": false,
"text": " traditional"
},
{
"id": 5780,
"logprob": -0.15112305,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0011291504,
"special": false,
"text": " learning"
},
{
"id": 5380,
"logprob": -0.3173828,
"special": false,
"text": "?\n"
}
],
"top_tokens": null
},
"generated_text": " and how does it differ from traditional machine learning?\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -6.3867188,
"text": "What"
},
{
"id": 374,
"logprob": -1.1318359,
"text": " is"
},
{
"id": 5655,
"logprob": -9.6875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.3007812,
"text": " learning"
},
{
"id": 30,
"logprob": -2.4902344,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 323,
"logprob": -1.1171875,
"special": false,
"text": " and"
},
{
"id": 1268,
"logprob": -0.9477539,
"special": false,
"text": " how"
},
{
"id": 1587,
"logprob": -0.51464844,
"special": false,
"text": " does"
},
{
"id": 433,
"logprob": -0.043182373,
"special": false,
"text": " it"
},
{
"id": 1782,
"logprob": -1.0810547,
"special": false,
"text": " differ"
},
{
"id": 505,
"logprob": -0.005054474,
"special": false,
"text": " from"
},
{
"id": 8776,
"logprob": -0.47485352,
"special": false,
"text": " traditional"
},
{
"id": 5780,
"logprob": -0.15112305,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0011291504,
"special": false,
"text": " learning"
},
{
"id": 5380,
"logprob": -0.3173828,
"special": false,
"text": "?\n"
}
],
"top_tokens": null
},
"generated_text": " and how does it differ from traditional machine learning?\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -6.3867188,
"text": "What"
},
{
"id": 374,
"logprob": -1.1318359,
"text": " is"
},
{
"id": 5655,
"logprob": -9.6875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.3007812,
"text": " learning"
},
{
"id": 30,
"logprob": -2.4902344,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 323,
"logprob": -1.1171875,
"special": false,
"text": " and"
},
{
"id": 1268,
"logprob": -0.9477539,
"special": false,
"text": " how"
},
{
"id": 1587,
"logprob": -0.51464844,
"special": false,
"text": " does"
},
{
"id": 433,
"logprob": -0.043182373,
"special": false,
"text": " it"
},
{
"id": 1782,
"logprob": -1.0810547,
"special": false,
"text": " differ"
},
{
"id": 505,
"logprob": -0.005054474,
"special": false,
"text": " from"
},
{
"id": 8776,
"logprob": -0.47485352,
"special": false,
"text": " traditional"
},
{
"id": 5780,
"logprob": -0.15112305,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0011291504,
"special": false,
"text": " learning"
},
{
"id": 5380,
"logprob": -0.3173828,
"special": false,
"text": "?\n"
}
],
"top_tokens": null
},
"generated_text": " and how does it differ from traditional machine learning?\n"
}
]

Snapshot: test_compressed_tensors_w8a8_int_dynamic_weight (generated)

@@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 3838,
"logprob": null,
"text": "What"
},
{
"id": 374,
"logprob": -8.59375,
"text": " is"
},
{
"id": 5538,
"logprob": -10.921875,
"text": " deep"
},
{
"id": 6832,
"logprob": -0.56347656,
"text": " learning"
},
{
"id": 30,
"logprob": -1.5,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18183,
"logprob": -1.6669922,
"special": false,
"text": " Deep"
},
{
"id": 6832,
"logprob": -0.08959961,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.14685059,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.125,
"special": false,
"text": " a"
},
{
"id": 25993,
"logprob": -0.81640625,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.0013418198,
"special": false,
"text": " of"
},
{
"id": 5662,
"logprob": -0.16027832,
"special": false,
"text": " machine"
},
{
"id": 6832,
"logprob": -0.0016393661,
"special": false,
"text": " learning"
},
{
"id": 429,
"logprob": -0.4477539,
"special": false,
"text": " that"
},
{
"id": 5711,
"logprob": -1.2802734,
"special": false,
"text": " uses"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that uses"
}

Snapshot: test_compressed_tensors_w8a8_int_dynamic_weight_all_params (generated)

@@ -0,0 +1,94 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 3838,
"logprob": null,
"text": "What"
},
{
"id": 374,
"logprob": -8.59375,
"text": " is"
},
{
"id": 5538,
"logprob": -10.921875,
"text": " deep"
},
{
"id": 6832,
"logprob": -0.56347656,
"text": " learning"
}
],
"seed": 0,
"tokens": [
{
"id": 1939,
"logprob": -2.2675781,
"special": false,
"text": "?\n\n"
},
{
"id": 33464,
"logprob": 0.0,
"special": false,
"text": "Deep"
},
{
"id": 20909,
"logprob": -0.37695312,
"special": false,
"text": " Learning"
},
{
"id": 4102,
"logprob": -1.9316406,
"special": false,
"text": " "
},
{
"id": 285,
"logprob": 0.0,
"special": false,
"text": "is"
},
{
"id": 458,
"logprob": -0.80859375,
"special": false,
"text": " an"
},
{
"id": 3082,
"logprob": -1.4541016,
"special": false,
"text": " area"
},
{
"id": 315,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 20443,
"logprob": -0.5136719,
"special": false,
"text": " artificial"
},
{
"id": 11229,
"logprob": 0.0,
"special": false,
"text": " intelligence"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning?\n\nDeep Learning is an area of artificial intelligence"
}

Snapshot: test_compressed_tensors_w8a8_int_dynamic_weight_load (generated)

@@ -0,0 +1,398 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 3838,
"logprob": null,
"text": "What"
},
{
"id": 374,
"logprob": -8.59375,
"text": " is"
},
{
"id": 5538,
"logprob": -10.921875,
"text": " deep"
},
{
"id": 6832,
"logprob": -0.56347656,
"text": " learning"
},
{
"id": 30,
"logprob": -1.5,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18183,
"logprob": -1.6669922,
"special": false,
"text": " Deep"
},
{
"id": 6832,
"logprob": -0.08959961,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.14685059,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.125,
"special": false,
"text": " a"
},
{
"id": 25993,
"logprob": -0.81640625,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.0013418198,
"special": false,
"text": " of"
},
{
"id": 5662,
"logprob": -0.16259766,
"special": false,
"text": " machine"
},
{
"id": 6832,
"logprob": -0.0016393661,
"special": false,
"text": " learning"
},
{
"id": 429,
"logprob": -0.4477539,
"special": false,
"text": " that"
},
{
"id": 5711,
"logprob": -1.2802734,
"special": false,
"text": " uses"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that uses"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 3838,
"logprob": null,
"text": "What"
},
{
"id": 374,
"logprob": -8.59375,
"text": " is"
},
{
"id": 5538,
"logprob": -10.921875,
"text": " deep"
},
{
"id": 6832,
"logprob": -0.56347656,
"text": " learning"
},
{
"id": 30,
"logprob": -1.5,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18183,
"logprob": -1.6669922,
"special": false,
"text": " Deep"
},
{
"id": 6832,
"logprob": -0.08959961,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.14685059,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.125,
"special": false,
"text": " a"
},
{
"id": 25993,
"logprob": -0.81640625,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.0013418198,
"special": false,
"text": " of"
},
{
"id": 5662,
"logprob": -0.16259766,
"special": false,
"text": " machine"
},
{
"id": 6832,
"logprob": -0.0016393661,
"special": false,
"text": " learning"
},
{
"id": 429,
"logprob": -0.4477539,
"special": false,
"text": " that"
},
{
"id": 5711,
"logprob": -1.2802734,
"special": false,
"text": " uses"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that uses"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 3838,
"logprob": null,
"text": "What"
},
{
"id": 374,
"logprob": -8.59375,
"text": " is"
},
{
"id": 5538,
"logprob": -10.921875,
"text": " deep"
},
{
"id": 6832,
"logprob": -0.56347656,
"text": " learning"
},
{
"id": 30,
"logprob": -1.5,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18183,
"logprob": -1.6669922,
"special": false,
"text": " Deep"
},
{
"id": 6832,
"logprob": -0.08959961,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.14685059,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.125,
"special": false,
"text": " a"
},
{
"id": 25993,
"logprob": -0.81640625,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.0013418198,
"special": false,
"text": " of"
},
{
"id": 5662,
"logprob": -0.16259766,
"special": false,
"text": " machine"
},
{
"id": 6832,
"logprob": -0.0016393661,
"special": false,
"text": " learning"
},
{
"id": 429,
"logprob": -0.4477539,
"special": false,
"text": " that"
},
{
"id": 5711,
"logprob": -1.2802734,
"special": false,
"text": " uses"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that uses"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 3838,
"logprob": null,
"text": "What"
},
{
"id": 374,
"logprob": -8.59375,
"text": " is"
},
{
"id": 5538,
"logprob": -10.921875,
"text": " deep"
},
{
"id": 6832,
"logprob": -0.56347656,
"text": " learning"
},
{
"id": 30,
"logprob": -1.5,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18183,
"logprob": -1.6669922,
"special": false,
"text": " Deep"
},
{
"id": 6832,
"logprob": -0.08959961,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.14685059,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.125,
"special": false,
"text": " a"
},
{
"id": 25993,
"logprob": -0.81640625,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.0013418198,
"special": false,
"text": " of"
},
{
"id": 5662,
"logprob": -0.16259766,
"special": false,
"text": " machine"
},
{
"id": 6832,
"logprob": -0.0016393661,
"special": false,
"text": " learning"
},
{
"id": 429,
"logprob": -0.4477539,
"special": false,
"text": " that"
},
{
"id": 5711,
"logprob": -1.2802734,
"special": false,
"text": " uses"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that uses"
}
]

integration-tests/models/test_compressed_tensors_w8a8_int.py

@@ -0,0 +1,88 @@
import pytest


@pytest.fixture(scope="module")
def compressed_tensors_w8a8_int_handle(launcher):
    with launcher(
        "neuralmagic/Llama-3.2-3B-Instruct-quantized.w8a8",
        num_shard=2,
        quantize="compressed-tensors",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def compressed_tensors_w8a8_int(compressed_tensors_w8a8_int_handle):
    await compressed_tensors_w8a8_int_handle.health(300)
    return compressed_tensors_w8a8_int_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_compressed_tensors_w8a8_int(
    compressed_tensors_w8a8_int, response_snapshot
):
    response = await compressed_tensors_w8a8_int.generate(
        "What is deep learning?",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert (
        response.generated_text
        == " and how does it differ from traditional machine learning?\n"
    )
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_compressed_tensors_w8a8_int_all_params(
    compressed_tensors_w8a8_int, response_snapshot
):
    response = await compressed_tensors_w8a8_int.generate(
        "What is deep learning",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "What is deep learning?\nDeep learning, also known as neural network or"
    )
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_compressed_tensors_w8a8_int_load(
    compressed_tensors_w8a8_int, generate_load, response_snapshot
):
    responses = await generate_load(
        compressed_tensors_w8a8_int,
        "What is deep learning?",
        max_new_tokens=10,
        n=4,
    )

    assert (
        responses[0].generated_text
        == " and how does it differ from traditional machine learning?\n"
    )
    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot

integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py

@@ -0,0 +1,90 @@
import pytest


@pytest.fixture(scope="module")
def compressed_tensors_w8a8_int_dynamic_weight_handle(launcher):
    with launcher(
        "danieldk/Qwen2.5-1.5B-Instruct-w8a8-int-dynamic-weight",
        num_shard=2,
        quantize="compressed-tensors",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def compressed_tensors_w8a8_int_dynamic_weight(
    compressed_tensors_w8a8_int_dynamic_weight_handle,
):
    await compressed_tensors_w8a8_int_dynamic_weight_handle.health(300)
    return compressed_tensors_w8a8_int_dynamic_weight_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_compressed_tensors_w8a8_int_dynamic_weight(
    compressed_tensors_w8a8_int_dynamic_weight, response_snapshot
):
    response = await compressed_tensors_w8a8_int_dynamic_weight.generate(
        "What is deep learning?",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert (
        response.generated_text
        == " Deep learning is a subset of machine learning that uses"
    )
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_compressed_tensors_w8a8_int_dynamic_weight_all_params(
    compressed_tensors_w8a8_int_dynamic_weight, response_snapshot
):
    response = await compressed_tensors_w8a8_int_dynamic_weight.generate(
        "What is deep learning",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "What is deep learning?\n\nDeep Learning is an area of artificial intelligence"
    )
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_compressed_tensors_w8a8_int_dynamic_weight_load(
    compressed_tensors_w8a8_int_dynamic_weight, generate_load, response_snapshot
):
    responses = await generate_load(
        compressed_tensors_w8a8_int_dynamic_weight,
        "What is deep learning?",
        max_new_tokens=10,
        n=4,
    )

    assert (
        responses[0].generated_text
        == " Deep learning is a subset of machine learning that uses"
    )
    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot

server/poetry.lock (generated)

@@ -1288,12 +1288,12 @@ files = [

 [[package]]
 name = "marlin-kernels"
-version = "0.3.1"
+version = "0.3.3"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:705c89ed54977099a40b37dc0c796964649024f1a8819a1832118cd7b146efe1"},
+    {file = "marlin_kernels-0.3.3+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:36da87e2d486083147c16845f193a438602a0dbc7a0ffb908fbec416c05c5951"},
 ]

 [package.dependencies]
@@ -1301,16 +1301,16 @@ torch = "*"

 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.3/marlin_kernels-0.3.3+cu123torch2.4-cp310-cp310-linux_x86_64.whl"

 [[package]]
 name = "marlin-kernels"
-version = "0.3.1"
+version = "0.3.3"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:e1f3d123eca643149d0a4f6b81c4405d78abb3a694a78fccc8670a25b3404406"},
+    {file = "marlin_kernels-0.3.3+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:c9ca340acdf27df009bf23ee1f37978f999b1a1378736dc3306df27eb48e364d"},
 ]

 [package.dependencies]
@@ -1318,16 +1318,16 @@ torch = "*"

 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.3/marlin_kernels-0.3.3+cu123torch2.4-cp311-cp311-linux_x86_64.whl"

 [[package]]
 name = "marlin-kernels"
-version = "0.3.1"
+version = "0.3.3"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9d68367fd5e1caf2edc90b77ad5d074b11586012265a3147ecca1f1171ae22f8"},
+    {file = "marlin_kernels-0.3.3+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:21e36b4880fc630882c8265e0cd27b379e40b1b87512f92a321506f4e5397d26"},
 ]

 [package.dependencies]
@@ -1335,16 +1335,16 @@ torch = "*"

 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.3/marlin_kernels-0.3.3+cu123torch2.4-cp312-cp312-linux_x86_64.whl"

 [[package]]
 name = "marlin-kernels"
-version = "0.3.1"
+version = "0.3.3"
 description = "Marlin quantization kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d962277c5f7642972e298650913dd0546b9f735b706dc88bb34955b3cac7f330"},
+    {file = "marlin_kernels-0.3.3+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:cdbf46e68313f76e9648ce7255353763cadbe14b7a789e01f5d502b76d64ee35"},
 ]

 [package.dependencies]
@@ -1352,7 +1352,7 @@ torch = "*"

 [package.source]
 type = "url"
-url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
+url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.3/marlin_kernels-0.3.3+cu123torch2.4-cp39-cp39-linux_x86_64.whl"

 [[package]]
 name = "mdurl"

server/pyproject.toml

@@ -48,10 +48,10 @@ attention-kernels = [
     { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
 marlin-kernels = [
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.3/marlin_kernels-0.3.3+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.3/marlin_kernels-0.3.3+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.3/marlin_kernels-0.3.3+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.3/marlin_kernels-0.3.3+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
 moe-kernels = [
     { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },

server/text_generation_server/layers/compressed_tensors/loader.py

@@ -12,6 +12,7 @@ from pydantic import ValidationError
 from torch import nn

 from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
+from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader
 from text_generation_server.layers.compressed_tensors.wna16_int import WNA16Loader
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
@@ -151,6 +152,17 @@ class CompressedTensorsLoader(WeightsLoader):
         ):
             # INT W4A16 or W8A16 (GPTQ/AWQ-like).
             return WNA16Loader(weights)
+        elif (
+            format
+            in {
+                CompressionFormat.int_quantized.value,
+                CompressionFormat.naive_quantized.value,
+            }
+            and weights is not None
+            and weights.type == QuantizationType.INT
+            and weights.num_bits == 8
+        ):
+            return W8A8IntLoader(input_args=input_activations, weight_args=weights)
         else:
             raise ValueError(
                 f"Group '{group_name}' has unsupported compressed-tensors configuration"

server/text_generation_server/layers/compressed_tensors/w8a8_int.py

@@ -0,0 +1,359 @@
from typing import List, Optional, Union, TypeVar
from dataclasses import dataclass

from loguru import logger
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType

from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None


class W8A8IntLoader(WeightsLoader):
    """
    Loader for w8a8 integer compressed-tensors parameters.
    """

    def __init__(
        self,
        *,
        input_args: Optional[QuantizationArgs],
        weight_args: QuantizationArgs,
    ):
        if weight_args.type != QuantizationType.INT or weight_args.num_bits != 8:
            raise ValueError(
                f"{type(self).__name__} only supports w8a8 int checkpoints"
            )

        if not weight_args.symmetric:
            raise ValueError("Checkpoints with asymmetric weights are not supported")

        self.load_weight_scale = not weight_args.dynamic

        if input_args is not None:
            static = not input_args.dynamic
            symmetric = input_args.symmetric

            self.load_input_scale = static
            self.load_input_zero_point = static and not symmetric
            self.input_symmetric = input_args.symmetric

            if static:
                # People shouldn't really use static input quantization,
                # the output is pretty bad.
                log_once(
                    logger.warning,
                    "Using W8A8 int with static input quantization results in large regressions in accuracy. Consider dynamic input quantization instead.",
                )
        else:
            self.load_input_scale = False
            self.load_input_zero_point = False
            self.input_symmetric = True

    def __str__(self) -> str:
        def scale_to_str(scale):
            return "static" if scale else "dynamic"

        def symmetric_to_str(symmetric):
            return "symmetric" if symmetric else "asymmetric"

        return f"{self.__class__.__name__} (w8a8 int, input: {symmetric_to_str(self.input_symmetric)}, {scale_to_str(self.load_input_scale)})"

    def get_weights(self, weights: "Weights", prefix: str):
        w = weights.get_tensor(f"{prefix}.weight", to_dtype=False)

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = weights.get_tensor(
                f"{prefix}.weight_scale", to_dtype=False
            ).reshape(-1)

        input_scale = None
        if self.load_input_scale:
            input_scale = weights.get_tensor(
                f"{prefix}.input_scale", to_dtype=False
            ).reshape(-1)

        input_zero_point = None
        if self.load_input_zero_point:
            input_zero_point = _get_tensor_or_else(
                weights,
                f"{prefix}.input_zero_point",
                torch.zeros((1,), device=w.device, dtype=torch.int8),
            ).reshape(-1)

        return Int8Weight(
            input_scale=input_scale,
            input_zero_point=input_zero_point,
            input_symmetric=self.input_symmetric,
            weight=w,
            weight_scale=weight_scale,
        )

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes, to_dtype=False
        )

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            if weight_scale.numel() > 1:
                weight_scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            weight_scale = weight_scale.reshape(-1)

        input_scale = None
        if self.load_input_scale:
            input_scale = weights.get_tensor(
                f"{prefix}.input_scale", to_dtype=False
            ).reshape(-1)

        input_zero_point = None
        if self.load_input_zero_point:
            input_zero_point = _get_tensor_or_else(
                weights,
                f"{prefix}.input_zero_point",
                torch.zeros((1,), device=w.device, dtype=torch.int8),
            ).reshape(-1)

        return Int8Weight(
            input_scale=input_scale,
            input_zero_point=input_zero_point,
            input_symmetric=self.input_symmetric,
            weight=w,
            weight_scale=weight_scale,
        )

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_dtype=False) for p in prefixes
        ]
        shapes = [x.shape for x in w]

        w = torch.cat(w, dim=dim)

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1)

        input_scale = None
        if self.load_input_scale:
            input_scale = [
                weights.get_tensor(f"{p}.input_scale", to_dtype=False) for p in prefixes
            ]
            input_scale = torch.cat(input_scale, dim=0)

        input_zero_point = None
        if self.load_input_zero_point:
            input_zero_point = [
                _get_tensor_or_else(
                    weights,
                    f"{prefix}.input_zero_point",
                    torch.zeros((1,), device=w.device, dtype=torch.int8),
                )
                for prefix in prefixes
            ]
            input_zero_point = torch.cat(input_zero_point, dim=0)

        return Int8Weight(
            input_scale=input_scale,
            input_zero_point=input_zero_point,
            input_symmetric=self.input_symmetric,
            weight=w,
            weight_scale=weight_scale,
        )

    def get_weights_row(self, weights: "Weights", prefix: str):
        w = weights.get_sharded(f"{prefix}.weight", dim=1, to_dtype=False)

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = weights.get_tensor(
                f"{prefix}.weight_scale", to_dtype=False
            ).reshape(-1)

        input_scale = None
        if self.load_input_scale:
            input_scale = weights.get_tensor(
                f"{prefix}.input_scale", to_dtype=False
            ).reshape(-1)

        input_zero_point = None
        if self.load_input_zero_point:
            input_zero_point = _get_tensor_or_else(
                weights,
                f"{prefix}.input_zero_point",
                torch.zeros((1,), device=w.device, dtype=torch.int8),
            ).reshape(-1)

        return Int8Weight(
            input_scale=input_scale,
            input_zero_point=input_zero_point,
            input_symmetric=self.input_symmetric,
            weight=w,
            weight_scale=weight_scale,
        )


OtherT = TypeVar("OtherT")


def _get_tensor_or_else(
    weights: Weights, prefix: str, other: OtherT
) -> Union[torch.Tensor, OtherT]:
    # Even if a checkpoint uses e.g. zero-points, they can be elided:
    # https://github.com/neuralmagic/compressed-tensors/blob/db6ccb25b265e8370813ecab5e95714a6728b5a6/src/compressed_tensors/compressors/quantized_compressors/base.py#L105
    if weights.has_tensor(prefix):
        return weights.get_tensor(prefix, to_dtype=False)
    else:
        return other


@dataclass
class Int8Weight(Weight):
    input_scale: Optional[torch.Tensor]
    input_zero_point: Optional[torch.Tensor]
    input_symmetric: bool
    weight: torch.Tensor
    weight_scale: Optional[torch.Tensor]

    def get_linear(self, bias: torch.Tensor):
        if self.weight_scale is None:
            assert marlin_kernels is not None
            qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
            return W8A8IntLinear(
                bias=bias,
                input_scale=self.input_scale,
                input_zero_point=self.input_zero_point,
                input_symmetric=self.input_symmetric,
                weight=qweight,
                weight_scale=weight_scale,
            )
        else:
            return W8A8IntLinear(
                bias=bias,
                input_scale=self.input_scale,
                input_zero_point=self.input_zero_point,
                input_symmetric=self.input_symmetric,
                weight=self.weight,
                weight_scale=self.weight_scale,
            )


class W8A8IntLinear(torch.nn.Module):
    def __init__(
        self,
        *,
        bias: Optional[torch.Tensor],
        input_scale: Optional[torch.Tensor],
        input_zero_point: Optional[torch.Tensor],
        input_symmetric: bool,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ):
        super().__init__()

        input_scale = (
            input_scale.to(torch.float32) if input_scale is not None else input_scale
        )
        weight_scale = weight_scale.to(torch.float32)

        self.bias = bias
        self.input_symmetric = input_symmetric
        # cutlass kernels require transposed weights.
        self.weight = weight.t()
        self.weight_scale = weight_scale

        if input_scale is not None:
            if input_zero_point is None:
                # Symmetric: simply use the largest scale to cover fused layers.
                input_scale = input_scale.max()
            else:
                # Asymmetric: find the range that contains all individual ranges.
                input_zero_point = input_zero_point.to(torch.int32)
                int8_info = torch.iinfo(torch.int8)

                # Find the most extreme values of all zero point/input scale
                # pairs.
                range_min = (input_scale * (int8_info.min - input_zero_point)).min()
                range_max = (input_scale * (int8_info.max - input_zero_point)).max()

                # Calculate new scale and zero point.
                input_scale = (range_max - range_min) / (int8_info.max - int8_info.min)
                input_zero_point = int8_info.min - (range_min / input_scale)
                input_zero_point = input_zero_point.to(torch.int32)

                self.range_min = (
                    input_scale * (int8_info.min - input_zero_point)
                ).min()
                self.range_max = (
                    input_scale * (int8_info.max - input_zero_point)
                ).max()

        self.input_scale = input_scale
        self.input_zero_point = input_zero_point

        if input_symmetric:
            self.zero_point_adj = None
        else:
            # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md#scaledepilogueazp
            self.zero_point_adj = self.weight.sum(
                dim=0, keepdim=True, dtype=torch.int32
            )
            if input_zero_point is not None:
                self.zero_point_adj *= input_zero_point

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert marlin_kernels is not None

        qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
            input=input,
            scale=self.input_scale,
            azp=self.input_zero_point,
            symmetric=self.input_symmetric,
        )

        if self.input_symmetric:
            return marlin_kernels.cutlass_scaled_mm(
                a=qinput,
                b=self.weight,
                scale_a=input_scale,
                scale_b=self.weight_scale,
                out_dtype=input.dtype,
                bias=self.bias,
            )
        else:
            assert self.zero_point_adj is not None and input_scale is not None

            return marlin_kernels.cutlass_scaled_mm_azp(
                a=qinput,
                b=self.weight,
                scale_a=input_scale,
                scale_b=self.weight_scale,
                out_dtype=input.dtype,
                azp_adj=self.zero_point_adj,
                # Zero point is already in the adjustment when using static
                # input quantization.
                azp=input_zero_point if self.input_zero_point is None else None,
                bias=self.bias,
            )

server/text_generation_server/utils/weights.py

@@ -220,6 +220,7 @@
             tensor.dtype
             not in [
                 torch.float8_e4m3fn,
+                torch.int8,
                 torch.int16,
                 torch.int32,
                 torch.int64,
@@ -255,7 +256,8 @@
         # u4 which are disguised as int32. exl2 uses int16.
         # FP8 uses torch.float8_e4m3fn.
         if (
-            tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
+            tensor.dtype
+            not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32)
             and to_dtype
         ):
             tensor = tensor.to(dtype=self.dtype)
@@ -331,6 +333,7 @@
             tensor.dtype
             not in [
                 torch.float8_e4m3fn,
+                torch.int8,
                 torch.int16,
                 torch.int32,
                 torch.int64,
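
All three hunks encode the same rule: quantized storage dtypes must pass through tensor loading untouched instead of being cast to the model dtype. A condensed sketch of that rule (a hypothetical standalone helper, not the actual `Weights` method):

```python
import torch

# Storage dtypes used by quantized weights; torch.int8 is the new entry
# needed so w8a8 int checkpoints are not upcast to e.g. float16 on load.
QUANTIZED_DTYPES = {
    torch.float8_e4m3fn,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
}


def maybe_cast(
    tensor: torch.Tensor, model_dtype: torch.dtype, to_dtype: bool
) -> torch.Tensor:
    # Cast only "ordinary" float weights to the model dtype.
    if to_dtype and tensor.dtype not in QUANTIZED_DTYPES:
        tensor = tensor.to(dtype=model_dtype)
    return tensor
```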