fix: add missing tests and renames

drbh 2024-06-21 02:59:18 +00:00
parent 29e922d3d4
commit b16109966d

@@ -28,7 +28,7 @@ dummy_file_system = {
),
},
"test_get_weights_col_packed": {
"col_packed.weight": torch.tensor(
"weight.weight": torch.tensor(
[
[1, 2],
[3, 4],
@@ -39,7 +39,7 @@ dummy_file_system = {
),
},
"test_get_multi_weights_col": {
"col.weight": torch.tensor(
"weight.weight": torch.tensor(
[
[1, 2],
[3, 4],
@@ -48,7 +48,7 @@ dummy_file_system = {
],
dtype=torch.float32,
),
"col.weight": torch.tensor(
"weight.weight": torch.tensor(
[
[1, 2],
[3, 4],
@@ -59,7 +59,7 @@ dummy_file_system = {
),
},
"test_get_multi_weights_row": {
"row_packed.weight": torch.tensor(
"weight.weight": torch.tensor(
[
[1, 2],
[3, 4],
@@ -85,6 +85,10 @@ dummy_file_system = {
"gptq_bits": torch.tensor([8], dtype=torch.float32),
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
},
"test_get_weights_col_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
},
"test_get_multi_weights_row_gptq": {
"weight.qweight": torch.tensor(
[
@@ -118,7 +122,7 @@ dummy_file_system = {
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
},
"test_get_weights_col_packed_gptq": {
"col_packed.qweight": torch.tensor(
"weight.qweight": torch.tensor(
[
[1, 2],
[3, 4],
@@ -127,15 +131,15 @@ dummy_file_system = {
],
dtype=torch.int32,
),
"col_packed.g_idx": torch.tensor([1.0], dtype=torch.int32),
"col_packed.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
"col_packed.scales": torch.tensor([[8]], dtype=torch.float16),
"weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
"weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
"weight.scales": torch.tensor([[8]], dtype=torch.float16),
"gptq_bits": torch.tensor([8], dtype=torch.float32),
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
},
-# TODO: review id col packed exl2 is supported
+# TODO: review if col packed exl2 is supported
"test_get_weights_col_packed_exl2": {
"col_packed.q_weight": torch.tensor(
"weight.q_weight": torch.tensor(
[
[1, 2],
[3, 4],
@@ -144,10 +148,10 @@ dummy_file_system = {
],
dtype=torch.int32,
),
"col_packed.q_scale": torch.tensor([8], dtype=torch.int32),
"col_packed.q_invperm": torch.tensor([1.0], dtype=torch.int32),
"col_packed.q_scale_max": torch.tensor([100], dtype=torch.float16),
"col_packed.q_groups": torch.tensor([4], dtype=torch.int16),
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_multi_weights_row_exl2": {
"weight.q_weight": torch.tensor(
@@ -182,7 +186,7 @@ dummy_file_system = {
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_weights_col_exl2": {
"col_packed.q_weight": torch.tensor(
"weight.q_weight": torch.tensor(
[
[1, 2],
[3, 4],
@@ -191,10 +195,10 @@ dummy_file_system = {
],
dtype=torch.int32,
),
"col_packed.q_scale": torch.tensor([8], dtype=torch.int32),
"col_packed.q_invperm": torch.tensor([1.0], dtype=torch.int32),
"col_packed.q_scale_max": torch.tensor([100], dtype=torch.float16),
"col_packed.q_groups": torch.tensor([4], dtype=torch.int16),
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_multi_weights_row_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
@@ -205,8 +209,8 @@ dummy_file_system = {
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
"test_get_weights_col_packed_marlin": {
"col_packed.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"col_packed.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
}
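
Every fixture above keys its tensors as {prefix}.{suffix}, which is what makes the blanket rename to the single prefix "weight" safe. Below is a minimal sketch, not part of the commit, of the lookup these fixtures imply, assuming the mock weight source merges the selected per-test dicts and resolves names verbatim; DummyWeightStore is a hypothetical stand-in for this file's MockWeights:

import torch

# Hypothetical sketch of the lookup the fixtures above imply; the real
# MockWeights in this test file may resolve names differently.
class DummyWeightStore:
    def __init__(self, filenames, dummy_fs):
        # Merge the per-test tensor dicts selected by filename.
        self.tensors = {}
        for filename in filenames:
            self.tensors.update(dummy_fs[filename])

    def get_tensor(self, name):
        # "weight.weight", "weight.qweight", "weight.B", ... all resolve
        # here, so every test can share the single prefix "weight".
        return self.tensors[name]

store = DummyWeightStore(["test_get_weights_col_marlin"], dummy_file_system)
assert store.get_tensor("weight.B").dtype == torch.int32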
@@ -362,7 +366,7 @@ def test_get_weights_col_packed():
dummy_fs=dummy_file_system,
)
prefix = "col_packed"
prefix = "weight"
quantize = None
block_sizes = 1
@@ -398,7 +402,7 @@ def test_get_weights_col_packed_block_size():
dummy_fs=dummy_file_system,
)
prefix = "col_packed"
prefix = "weight"
quantize = None
block_sizes = 2
@@ -434,7 +438,7 @@ def test_get_weights_col_packed_block_size_arr():
dummy_fs=dummy_file_system,
)
prefix = "col_packed"
prefix = "weight"
quantize = None
block_sizes = [1, 1]
@@ -469,7 +473,7 @@ def test_get_multi_weights_col():
dummy_fs=dummy_file_system,
)
prefixes = ["col", "col"]
prefixes = ["weight", "weight"]
quantize = None
w = weights.get_multi_weights_col(
@@ -507,7 +511,7 @@ def test_get_multi_weights_row():
dummy_fs=dummy_file_system,
)
prefix = "row_packed"
prefix = "weight"
quantize = None
w = weights.get_multi_weights_row(
@@ -524,6 +528,47 @@ def test_get_multi_weights_row():
)
# test_get_weights_col
+def test_get_weights_col_awq():
+weights = MockWeights(
+[
+"test_get_weights_col_gptq",
+],
+device="cpu",
+dtype=torch.float32,
+process_group=dummy_process_group,
+dummy_fs=dummy_file_system,
+)
+prefix = "weight"
+quantize = "awq"
+w = weights.get_weights_col(
+prefix=prefix,
+quantize=quantize,
+)
+expected_weight = GPTQWeight(
+qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
+qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
+scales=torch.tensor([[100.0], [100.0]], dtype=torch.float16),
+g_idx=None,
+bits=8.0,
+groupsize=4.0,
+use_exllama=False,
+)
+assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
+assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
+assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
+assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
+assert w.bits == expected_weight.bits, "bits mismatch"
+assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
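
The new awq test reuses the gptq fixture and checks the result field by field against a GPTQWeight. The same seven asserts recur in the other GPTQ-shaped tests; a hedged sketch of a shared helper (hypothetical, not in this commit) that they could call instead:

import torch

# Hypothetical helper, not part of this commit: one call replacing the
# repeated per-field asserts on a GPTQWeight-shaped result.
def assert_gptq_weight_equal(w, expected):
    assert torch.allclose(w.qweight, expected.qweight), "qweight mismatch"
    assert torch.allclose(w.qzeros, expected.qzeros), "qzeros mismatch"
    assert torch.allclose(w.scales, expected.scales), "scales mismatch"
    assert w.g_idx == expected.g_idx, "g_idx mismatch"
    assert w.bits == expected.bits, "bits mismatch"
    assert w.groupsize == expected.groupsize, "groupsize mismatch"
    assert w.use_exllama == expected.use_exllama, "use_exllama mismatch"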
def test_get_weights_col_gtpq():
weights = MockWeights(
[
@@ -573,7 +618,7 @@ def test_get_weights_col_exl2():
dummy_fs=dummy_file_system,
)
prefix = "col_packed"
prefix = "weight"
quantize = "exl2"
w = weights.get_weights_col(
@@ -599,6 +644,34 @@ def test_get_weights_col_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
+def test_get_weights_col_marlin():
+weights = MockWeights(
+[
+"test_get_weights_col_marlin",
+],
+device="cpu",
+dtype=torch.float16,
+process_group=dummy_process_group,
+dummy_fs=dummy_file_system,
+)
+prefix = "weight"
+quantize = "marlin"
+w = weights.get_weights_col(
+prefix=prefix,
+quantize=quantize,
+)
+expected_weight = MarlinWeight(
+B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
+)
+assert torch.allclose(w.B, expected_weight.B), "B mismatch"
+assert torch.allclose(w.s, expected_weight.s), "s mismatch"
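
For reference, the expected MarlinWeight carries only two tensors: the packed int32 weight matrix B and the float16 per-group scales s. A minimal self-contained sketch with a hypothetical container, only to illustrate the shape of the assertion:

from dataclasses import dataclass

import torch

# Hypothetical stand-in for MarlinWeight, showing only the two fields the
# marlin tests compare.
@dataclass
class MarlinWeightSketch:
    B: torch.Tensor  # packed quantized weights, int32
    s: torch.Tensor  # per-group scales, float16

w = MarlinWeightSketch(
    B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
    s=torch.tensor([[0.5], [0.25]], dtype=torch.float16),
)
# Integer tensors compare exactly; the float16 scales use a tolerance.
assert torch.equal(w.B, torch.tensor([[1, 2], [3, 4]], dtype=torch.int32))
assert torch.allclose(w.s, torch.tensor([[0.5], [0.25]], dtype=torch.float16))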
# test_get_weights_col_packed
@@ -613,7 +686,7 @@ def test_get_weights_col_packed_awq():
dummy_fs=dummy_file_system,
)
prefix = "col_packed"
prefix = "weight"
quantize = "awq"
block_sizes = 1
@@ -654,7 +727,7 @@ def test_get_weights_col_packed_exl2():
dummy_fs=dummy_file_system,
)
prefix = "col_packed"
prefix = "weight"
quantize = "exl2"
block_sizes = 1
@@ -693,7 +766,7 @@ def test_get_weights_col_packed_gptq():
dummy_fs=dummy_file_system,
)
prefixes = ["col_packed"]
prefixes = ["weight"]
quantize = "gptq"
w = weights.get_multi_weights_col(
@@ -732,7 +805,7 @@ def test_get_weights_col_packed_marlin():
dummy_fs=dummy_file_system,
)
prefix = "col_packed"
prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_col(