mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: improve weight tests
This commit is contained in:
parent
313d29f1f9
commit
29e922d3d4
@ -1,3 +1,4 @@
|
|||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from text_generation_server.utils.weights import Weights
|
from text_generation_server.utils.weights import Weights
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
@ -26,7 +27,7 @@ dummy_file_system = {
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
"test_get_multi_weights_col_packed": {
|
"test_get_weights_col_packed": {
|
||||||
"col_packed.weight": torch.tensor(
|
"col_packed.weight": torch.tensor(
|
||||||
[
|
[
|
||||||
[1, 2],
|
[1, 2],
|
||||||
@ -36,7 +37,18 @@ dummy_file_system = {
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
),
|
),
|
||||||
"col_packed_2.weight": torch.tensor(
|
},
|
||||||
|
"test_get_multi_weights_col": {
|
||||||
|
"col.weight": torch.tensor(
|
||||||
|
[
|
||||||
|
[1, 2],
|
||||||
|
[3, 4],
|
||||||
|
[5, 6],
|
||||||
|
[7, 8],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
|
"col.weight": torch.tensor(
|
||||||
[
|
[
|
||||||
[1, 2],
|
[1, 2],
|
||||||
[3, 4],
|
[3, 4],
|
||||||
@ -57,6 +69,22 @@ dummy_file_system = {
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
"test_get_weights_col_gptq": {
|
||||||
|
"weight.qweight": torch.tensor(
|
||||||
|
[
|
||||||
|
[1, 2],
|
||||||
|
[3, 4],
|
||||||
|
[5, 6],
|
||||||
|
[7, 8],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
|
"weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
|
||||||
|
"weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
|
||||||
|
"weight.scales": torch.tensor([[100.0], [100.0]], dtype=torch.float16),
|
||||||
|
"gptq_bits": torch.tensor([8], dtype=torch.float32),
|
||||||
|
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
|
||||||
|
},
|
||||||
"test_get_multi_weights_row_gptq": {
|
"test_get_multi_weights_row_gptq": {
|
||||||
"weight.qweight": torch.tensor(
|
"weight.qweight": torch.tensor(
|
||||||
[
|
[
|
||||||
@ -89,7 +117,7 @@ dummy_file_system = {
|
|||||||
"gptq_bits": torch.tensor([8], dtype=torch.float32),
|
"gptq_bits": torch.tensor([8], dtype=torch.float32),
|
||||||
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
|
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
|
||||||
},
|
},
|
||||||
"test_get_multi_weights_col_packed_gptq": {
|
"test_get_weights_col_packed_gptq": {
|
||||||
"col_packed.qweight": torch.tensor(
|
"col_packed.qweight": torch.tensor(
|
||||||
[
|
[
|
||||||
[1, 2],
|
[1, 2],
|
||||||
@ -105,6 +133,22 @@ dummy_file_system = {
|
|||||||
"gptq_bits": torch.tensor([8], dtype=torch.float32),
|
"gptq_bits": torch.tensor([8], dtype=torch.float32),
|
||||||
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
|
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
|
||||||
},
|
},
|
||||||
|
# TODO: review if col packed exl2 is supported
|
||||||
|
"test_get_weights_col_packed_exl2": {
|
||||||
|
"col_packed.q_weight": torch.tensor(
|
||||||
|
[
|
||||||
|
[1, 2],
|
||||||
|
[3, 4],
|
||||||
|
[5, 6],
|
||||||
|
[7, 8],
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
"col_packed.q_scale": torch.tensor([8], dtype=torch.int32),
|
||||||
|
"col_packed.q_invperm": torch.tensor([1.0], dtype=torch.int32),
|
||||||
|
"col_packed.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||||
|
"col_packed.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||||
|
},
|
||||||
"test_get_multi_weights_row_exl2": {
|
"test_get_multi_weights_row_exl2": {
|
||||||
"weight.q_weight": torch.tensor(
|
"weight.q_weight": torch.tensor(
|
||||||
[
|
[
|
||||||
@ -134,9 +178,24 @@ dummy_file_system = {
|
|||||||
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
|
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
|
||||||
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
|
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
|
||||||
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
|
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
|
||||||
"weight.q_scale_max": torch.tensor([8], dtype=torch.float16),
|
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||||
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
|
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||||
},
|
},
|
||||||
|
"test_get_weights_col_exl2": {
|
||||||
|
"col_packed.q_weight": torch.tensor(
|
||||||
|
[
|
||||||
|
[1, 2],
|
||||||
|
[3, 4],
|
||||||
|
[5, 6],
|
||||||
|
[7, 8],
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
"col_packed.q_scale": torch.tensor([8], dtype=torch.int32),
|
||||||
|
"col_packed.q_invperm": torch.tensor([1.0], dtype=torch.int32),
|
||||||
|
"col_packed.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||||
|
"col_packed.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||||
|
},
|
||||||
"test_get_multi_weights_row_marlin": {
|
"test_get_multi_weights_row_marlin": {
|
||||||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||||
"weight.s": torch.tensor([0.5], dtype=torch.float16),
|
"weight.s": torch.tensor([0.5], dtype=torch.float16),
|
||||||
@ -145,7 +204,7 @@ dummy_file_system = {
|
|||||||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||||
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
|
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
|
||||||
},
|
},
|
||||||
"test_get_multi_weights_col_packed_marlin": {
|
"test_get_weights_col_packed_marlin": {
|
||||||
"col_packed.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
"col_packed.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||||
"col_packed.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
|
"col_packed.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
|
||||||
},
|
},
|
||||||
@ -295,7 +354,7 @@ def test_get_weights_col_packed():
|
|||||||
|
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_col_packed",
|
"test_get_weights_col_packed",
|
||||||
],
|
],
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@ -313,9 +372,41 @@ def test_get_weights_col_packed():
|
|||||||
block_sizes=block_sizes,
|
block_sizes=block_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(
|
||||||
|
w,
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
[1, 2],
|
||||||
|
[3, 4],
|
||||||
|
[5, 6],
|
||||||
|
[7, 8],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_weights_col_packed_block_size():
|
||||||
|
|
||||||
|
weights = MockWeights(
|
||||||
|
[
|
||||||
|
"test_get_weights_col_packed",
|
||||||
|
],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
process_group=dummy_process_group,
|
||||||
|
dummy_fs=dummy_file_system,
|
||||||
|
)
|
||||||
|
|
||||||
prefix = "col_packed"
|
prefix = "col_packed"
|
||||||
quantize = None
|
quantize = None
|
||||||
block_sizes = 1
|
block_sizes = 2
|
||||||
|
|
||||||
|
w = weights.get_weights_col_packed(
|
||||||
|
prefix=prefix,
|
||||||
|
quantize=quantize,
|
||||||
|
block_sizes=block_sizes,
|
||||||
|
)
|
||||||
|
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
w,
|
w,
|
||||||
@ -331,10 +422,11 @@ def test_get_weights_col_packed():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_col_packed():
|
def test_get_weights_col_packed_block_size_arr():
|
||||||
|
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_col_packed",
|
"test_get_weights_col_packed",
|
||||||
],
|
],
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@ -342,7 +434,42 @@ def test_get_multi_weights_col_packed():
|
|||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefixes = ["col_packed", "col_packed_2"]
|
prefix = "col_packed"
|
||||||
|
quantize = None
|
||||||
|
block_sizes = [1, 1]
|
||||||
|
|
||||||
|
w = weights.get_weights_col_packed(
|
||||||
|
prefix=prefix,
|
||||||
|
quantize=quantize,
|
||||||
|
block_sizes=block_sizes,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(
|
||||||
|
w,
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
[1, 2],
|
||||||
|
[3, 4],
|
||||||
|
[5, 6],
|
||||||
|
[7, 8],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_multi_weights_col():
|
||||||
|
weights = MockWeights(
|
||||||
|
[
|
||||||
|
"test_get_multi_weights_col",
|
||||||
|
],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
process_group=dummy_process_group,
|
||||||
|
dummy_fs=dummy_file_system,
|
||||||
|
)
|
||||||
|
|
||||||
|
prefixes = ["col", "col"]
|
||||||
quantize = None
|
quantize = None
|
||||||
|
|
||||||
w = weights.get_multi_weights_col(
|
w = weights.get_multi_weights_col(
|
||||||
@ -397,10 +524,10 @@ def test_get_multi_weights_row():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_row_gptq():
|
def test_get_weights_col_gtpq():
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_row_gptq",
|
"test_get_weights_col_gptq",
|
||||||
],
|
],
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@ -411,15 +538,15 @@ def test_get_multi_weights_row_gptq():
|
|||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "gptq"
|
quantize = "gptq"
|
||||||
|
|
||||||
w = weights.get_multi_weights_row(
|
w = weights.get_weights_col(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_weight = GPTQWeight(
|
expected_weight = GPTQWeight(
|
||||||
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
|
||||||
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||||
scales=torch.tensor([8.0], dtype=torch.float16),
|
scales=torch.tensor([[100.0], [100.0]], dtype=torch.float16),
|
||||||
g_idx=torch.tensor([1], dtype=torch.int32),
|
g_idx=torch.tensor([1], dtype=torch.int32),
|
||||||
bits=8.0,
|
bits=8.0,
|
||||||
groupsize=4.0,
|
groupsize=4.0,
|
||||||
@ -435,6 +562,262 @@ def test_get_multi_weights_row_gptq():
|
|||||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_weights_col_exl2():
|
||||||
|
weights = MockWeights(
|
||||||
|
[
|
||||||
|
"test_get_weights_col_exl2",
|
||||||
|
],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
process_group=dummy_process_group,
|
||||||
|
dummy_fs=dummy_file_system,
|
||||||
|
)
|
||||||
|
|
||||||
|
prefix = "col_packed"
|
||||||
|
quantize = "exl2"
|
||||||
|
|
||||||
|
w = weights.get_weights_col(
|
||||||
|
prefix=prefix,
|
||||||
|
quantize=quantize,
|
||||||
|
)
|
||||||
|
|
||||||
|
scaled_scale_max = 0.3906 * 256
|
||||||
|
expected_weight = Exl2Weight(
|
||||||
|
q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||||
|
q_scale=torch.tensor([8], dtype=torch.int32),
|
||||||
|
q_invperm=torch.tensor([1], dtype=torch.int16),
|
||||||
|
q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
|
||||||
|
q_groups=torch.tensor([4], dtype=torch.int16),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
|
||||||
|
assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
|
||||||
|
assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
|
||||||
|
assert torch.allclose(
|
||||||
|
w.q_scale_max, expected_weight.q_scale_max
|
||||||
|
), "q_scale_max mismatch"
|
||||||
|
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||||
|
|
||||||
|
|
||||||
|
# test_get_weights_col_packed
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_weights_col_packed_awq():
|
||||||
|
weights = MockWeights(
|
||||||
|
[
|
||||||
|
"test_get_weights_col_packed_gptq",
|
||||||
|
],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
process_group=dummy_process_group,
|
||||||
|
dummy_fs=dummy_file_system,
|
||||||
|
)
|
||||||
|
|
||||||
|
prefix = "col_packed"
|
||||||
|
quantize = "awq"
|
||||||
|
block_sizes = 1
|
||||||
|
|
||||||
|
w = weights.get_weights_col_packed(
|
||||||
|
prefix=prefix,
|
||||||
|
quantize=quantize,
|
||||||
|
block_sizes=block_sizes,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_weight = GPTQWeight(
|
||||||
|
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||||
|
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||||
|
scales=torch.tensor([[8.0]], dtype=torch.float16),
|
||||||
|
g_idx=None,
|
||||||
|
bits=8.0,
|
||||||
|
groupsize=4.0,
|
||||||
|
use_exllama=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
|
||||||
|
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
|
||||||
|
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
|
||||||
|
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
|
||||||
|
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||||
|
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||||
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Review expected functionality")
|
||||||
|
def test_get_weights_col_packed_exl2():
|
||||||
|
weights = MockWeights(
|
||||||
|
[
|
||||||
|
"test_get_weights_col_packed_exl2",
|
||||||
|
],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
process_group=dummy_process_group,
|
||||||
|
dummy_fs=dummy_file_system,
|
||||||
|
)
|
||||||
|
|
||||||
|
prefix = "col_packed"
|
||||||
|
quantize = "exl2"
|
||||||
|
block_sizes = 1
|
||||||
|
|
||||||
|
w = weights.get_weights_col_packed(
|
||||||
|
prefix=prefix,
|
||||||
|
quantize=quantize,
|
||||||
|
block_sizes=block_sizes,
|
||||||
|
)
|
||||||
|
|
||||||
|
scaled_scale_max = 0.3906 * 256
|
||||||
|
expected_weight = Exl2Weight(
|
||||||
|
q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||||
|
q_scale=torch.tensor([8], dtype=torch.int32),
|
||||||
|
q_invperm=torch.tensor([1], dtype=torch.int16),
|
||||||
|
q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
|
||||||
|
q_groups=torch.tensor([4], dtype=torch.int16),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
|
||||||
|
assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
|
||||||
|
assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
|
||||||
|
assert torch.allclose(
|
||||||
|
w.q_scale_max, expected_weight.q_scale_max
|
||||||
|
), "q_scale_max mismatch"
|
||||||
|
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_weights_col_packed_gptq():
|
||||||
|
weights = MockWeights(
|
||||||
|
[
|
||||||
|
"test_get_weights_col_packed_gptq",
|
||||||
|
],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
process_group=dummy_process_group,
|
||||||
|
dummy_fs=dummy_file_system,
|
||||||
|
)
|
||||||
|
|
||||||
|
prefixes = ["col_packed"]
|
||||||
|
quantize = "gptq"
|
||||||
|
|
||||||
|
w = weights.get_multi_weights_col(
|
||||||
|
prefixes=prefixes,
|
||||||
|
quantize=quantize,
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_weight = GPTQWeight(
|
||||||
|
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||||
|
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||||
|
scales=torch.tensor([[8.0]], dtype=torch.float16),
|
||||||
|
g_idx=torch.tensor([1], dtype=torch.int32),
|
||||||
|
bits=8.0,
|
||||||
|
groupsize=4.0,
|
||||||
|
use_exllama=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
|
||||||
|
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
|
||||||
|
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
|
||||||
|
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
|
||||||
|
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||||
|
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||||
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_weights_col_packed_marlin():
|
||||||
|
weights = MockWeights(
|
||||||
|
[
|
||||||
|
"test_get_weights_col_packed_marlin",
|
||||||
|
],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float16,
|
||||||
|
process_group=dummy_process_group,
|
||||||
|
dummy_fs=dummy_file_system,
|
||||||
|
)
|
||||||
|
|
||||||
|
prefix = "col_packed"
|
||||||
|
quantize = "marlin"
|
||||||
|
|
||||||
|
w = weights.get_multi_weights_col(
|
||||||
|
prefixes=[prefix],
|
||||||
|
quantize=quantize,
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_weight = MarlinWeight(
|
||||||
|
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||||
|
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(expected_weight)
|
||||||
|
|
||||||
|
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
|
||||||
|
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
||||||
|
|
||||||
|
|
||||||
|
# test_get_multi_weights_col
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_multi_weights_col_awq():
|
||||||
|
weights = MockWeights(
|
||||||
|
[
|
||||||
|
"test_get_multi_weights_col_gptq",
|
||||||
|
],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
process_group=dummy_process_group,
|
||||||
|
dummy_fs=dummy_file_system,
|
||||||
|
)
|
||||||
|
|
||||||
|
prefixes = ["weight"]
|
||||||
|
quantize = "awq"
|
||||||
|
|
||||||
|
w = weights.get_multi_weights_col(
|
||||||
|
prefixes=prefixes,
|
||||||
|
quantize=quantize,
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_weight = GPTQWeight(
|
||||||
|
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||||
|
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||||
|
scales=torch.tensor([[8.0]], dtype=torch.float16),
|
||||||
|
g_idx=None,
|
||||||
|
bits=8.0,
|
||||||
|
groupsize=4.0,
|
||||||
|
use_exllama=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
|
||||||
|
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
|
||||||
|
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
|
||||||
|
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
|
||||||
|
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||||
|
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||||
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_multi_weights_col_exl2():
|
||||||
|
weights = MockWeights(
|
||||||
|
[
|
||||||
|
"test_get_multi_weights_col_exl2",
|
||||||
|
],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
process_group=dummy_process_group,
|
||||||
|
dummy_fs=dummy_file_system,
|
||||||
|
)
|
||||||
|
|
||||||
|
prefix = "weight"
|
||||||
|
quantize = "exl2"
|
||||||
|
|
||||||
|
try:
|
||||||
|
w = weights.get_multi_weights_col(
|
||||||
|
prefixes=[prefix],
|
||||||
|
quantize=quantize,
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
assert e.args[0] == "get_multi_weights_col is not supported for exl2"
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_col_gptq():
|
def test_get_multi_weights_col_gptq():
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
@ -474,10 +857,42 @@ def test_get_multi_weights_col_gptq():
|
|||||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_col_packed_gptq():
|
def test_get_multi_weights_col_marlin():
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_col_packed_gptq",
|
"test_get_multi_weights_col_marlin",
|
||||||
|
],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float16,
|
||||||
|
process_group=dummy_process_group,
|
||||||
|
dummy_fs=dummy_file_system,
|
||||||
|
)
|
||||||
|
|
||||||
|
prefix = "weight"
|
||||||
|
quantize = "marlin"
|
||||||
|
|
||||||
|
w = weights.get_multi_weights_col(
|
||||||
|
prefixes=[prefix],
|
||||||
|
quantize=quantize,
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_weight = MarlinWeight(
|
||||||
|
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||||
|
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
|
||||||
|
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
||||||
|
|
||||||
|
|
||||||
|
# test_get_multi_weights_row
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_multi_weights_row_awq():
|
||||||
|
weights = MockWeights(
|
||||||
|
[
|
||||||
|
"test_get_multi_weights_row_gptq",
|
||||||
],
|
],
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@ -485,20 +900,19 @@ def test_get_multi_weights_col_packed_gptq():
|
|||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefixes = ["col_packed"]
|
prefix = "weight"
|
||||||
quantize = "gptq"
|
quantize = "awq"
|
||||||
|
|
||||||
w = weights.get_multi_weights_col(
|
w = weights.get_multi_weights_row(
|
||||||
prefixes=prefixes,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
dim=0,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_weight = GPTQWeight(
|
expected_weight = GPTQWeight(
|
||||||
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||||
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||||
scales=torch.tensor([[8.0]], dtype=torch.float16),
|
scales=torch.tensor([8.0], dtype=torch.float16),
|
||||||
g_idx=torch.tensor([1], dtype=torch.int32),
|
g_idx=None,
|
||||||
bits=8.0,
|
bits=8.0,
|
||||||
groupsize=4.0,
|
groupsize=4.0,
|
||||||
use_exllama=False,
|
use_exllama=False,
|
||||||
@ -507,7 +921,7 @@ def test_get_multi_weights_col_packed_gptq():
|
|||||||
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
|
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
|
||||||
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
|
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
|
||||||
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
|
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
|
||||||
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
|
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
|
||||||
assert w.bits == expected_weight.bits, "bits mismatch"
|
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||||
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
@ -550,31 +964,7 @@ def test_get_multi_weights_row_exl2():
|
|||||||
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_col_exl2():
|
def test_get_multi_weights_row_gptq():
|
||||||
weights = MockWeights(
|
|
||||||
[
|
|
||||||
"test_get_multi_weights_col_exl2",
|
|
||||||
],
|
|
||||||
device="cpu",
|
|
||||||
dtype=torch.float32,
|
|
||||||
process_group=dummy_process_group,
|
|
||||||
dummy_fs=dummy_file_system,
|
|
||||||
)
|
|
||||||
|
|
||||||
prefix = "weight"
|
|
||||||
quantize = "exl2"
|
|
||||||
|
|
||||||
try:
|
|
||||||
w = weights.get_multi_weights_col(
|
|
||||||
prefixes=[prefix],
|
|
||||||
quantize=quantize,
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
assert e.args[0] == "get_multi_weights_col is not supported for exl2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_row_awq():
|
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_row_gptq",
|
"test_get_multi_weights_row_gptq",
|
||||||
@ -586,7 +976,7 @@ def test_get_multi_weights_row_awq():
|
|||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "awq"
|
quantize = "gptq"
|
||||||
|
|
||||||
w = weights.get_multi_weights_row(
|
w = weights.get_multi_weights_row(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
@ -597,7 +987,7 @@ def test_get_multi_weights_row_awq():
|
|||||||
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||||
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||||
scales=torch.tensor([8.0], dtype=torch.float16),
|
scales=torch.tensor([8.0], dtype=torch.float16),
|
||||||
g_idx=None,
|
g_idx=torch.tensor([1], dtype=torch.int32),
|
||||||
bits=8.0,
|
bits=8.0,
|
||||||
groupsize=4.0,
|
groupsize=4.0,
|
||||||
use_exllama=False,
|
use_exllama=False,
|
||||||
@ -606,7 +996,7 @@ def test_get_multi_weights_row_awq():
|
|||||||
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
|
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
|
||||||
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
|
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
|
||||||
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
|
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
|
||||||
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
|
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
|
||||||
assert w.bits == expected_weight.bits, "bits mismatch"
|
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||||
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
@ -638,63 +1028,3 @@ def test_get_multi_weights_row_marlin():
|
|||||||
|
|
||||||
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
|
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
|
||||||
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_col_marlin():
|
|
||||||
weights = MockWeights(
|
|
||||||
[
|
|
||||||
"test_get_multi_weights_col_marlin",
|
|
||||||
],
|
|
||||||
device="cpu",
|
|
||||||
dtype=torch.float16,
|
|
||||||
process_group=dummy_process_group,
|
|
||||||
dummy_fs=dummy_file_system,
|
|
||||||
)
|
|
||||||
|
|
||||||
prefix = "weight"
|
|
||||||
quantize = "marlin"
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_col(
|
|
||||||
prefixes=[prefix],
|
|
||||||
quantize=quantize,
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_weight = MarlinWeight(
|
|
||||||
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
|
||||||
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
|
|
||||||
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_col_packed_marlin():
|
|
||||||
weights = MockWeights(
|
|
||||||
[
|
|
||||||
"test_get_multi_weights_col_packed_marlin",
|
|
||||||
],
|
|
||||||
device="cpu",
|
|
||||||
dtype=torch.float16,
|
|
||||||
process_group=dummy_process_group,
|
|
||||||
dummy_fs=dummy_file_system,
|
|
||||||
)
|
|
||||||
|
|
||||||
prefix = "col_packed"
|
|
||||||
quantize = "marlin"
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_col(
|
|
||||||
prefixes=[prefix],
|
|
||||||
quantize=quantize,
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_weight = MarlinWeight(
|
|
||||||
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
|
||||||
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
|
||||||
)
|
|
||||||
|
|
||||||
print(expected_weight)
|
|
||||||
|
|
||||||
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
|
|
||||||
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
|
||||||
|
Loading…
Reference in New Issue
Block a user