fix: adjust so all tests pass

This commit is contained in:
drbh 2024-06-20 19:28:29 +00:00
parent 7ee217475e
commit 313d29f1f9

View File

@ -85,12 +85,12 @@ dummy_file_system = {
), ),
"weight.g_idx": torch.tensor([1.0], dtype=torch.int32), "weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
"weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32), "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
"weight.scales": torch.tensor([8], dtype=torch.float16), "weight.scales": torch.tensor([[8]], dtype=torch.float16),
"gptq_bits": torch.tensor([8], dtype=torch.float32), "gptq_bits": torch.tensor([8], dtype=torch.float32),
"gptq_groupsize": torch.tensor([4], dtype=torch.float32), "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
}, },
"test_get_multi_weights_col_packed_gptq": { "test_get_multi_weights_col_packed_gptq": {
"col_packed.weight.qweight": torch.tensor( "col_packed.qweight": torch.tensor(
[ [
[1, 2], [1, 2],
[3, 4], [3, 4],
@ -99,9 +99,9 @@ dummy_file_system = {
], ],
dtype=torch.int32, dtype=torch.int32,
), ),
"col_packed.weight.g_idx": torch.tensor([1.0], dtype=torch.int32), "col_packed.g_idx": torch.tensor([1.0], dtype=torch.int32),
"col_packed.weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32), "col_packed.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
"col_packed.weight.scales": torch.tensor([8], dtype=torch.float16), "col_packed.scales": torch.tensor([[8]], dtype=torch.float16),
"gptq_bits": torch.tensor([8], dtype=torch.float32), "gptq_bits": torch.tensor([8], dtype=torch.float32),
"gptq_groupsize": torch.tensor([4], dtype=torch.float32), "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
}, },
@ -117,7 +117,7 @@ dummy_file_system = {
), ),
"weight.q_scale": torch.tensor([8], dtype=torch.int32), "weight.q_scale": torch.tensor([8], dtype=torch.int32),
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32), "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
"weight.q_scale_max": torch.tensor([8], dtype=torch.float16), "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16), "weight.q_groups": torch.tensor([4], dtype=torch.int16),
}, },
"test_get_multi_weights_col_exl2": { "test_get_multi_weights_col_exl2": {
@ -143,11 +143,11 @@ dummy_file_system = {
}, },
"test_get_multi_weights_col_marlin": { "test_get_multi_weights_col_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([0.5], dtype=torch.float16), "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
}, },
"test_get_multi_weights_col_packed_marlin": { "test_get_multi_weights_col_packed_marlin": {
"col_packed.weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "col_packed.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"col_packed.weight.s": torch.tensor([0.5], dtype=torch.float16), "col_packed.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
}, },
} }
@ -417,12 +417,12 @@ def test_get_multi_weights_row_gptq():
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
qzeros=torch.tensor([[1.0], [2.0]], dtype=torch.float32), qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
scales=torch.tensor([8], dtype=torch.int32), scales=torch.tensor([8.0], dtype=torch.float16),
g_idx=torch.tensor([1.0], dtype=torch.float32), g_idx=torch.tensor([1], dtype=torch.int32),
bits=torch.tensor([8], dtype=torch.float32), bits=8.0,
groupsize=torch.tensor([4], dtype=torch.float32), groupsize=4.0,
use_exllama=False, use_exllama=False,
) )
@ -456,12 +456,12 @@ def test_get_multi_weights_col_gptq():
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
qzeros=torch.tensor([[1.0], [2.0]], dtype=torch.float32), qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
scales=torch.tensor([8], dtype=torch.int32), scales=torch.tensor([[8.0]], dtype=torch.float16),
g_idx=torch.tensor([1.0], dtype=torch.float32), g_idx=torch.tensor([1], dtype=torch.int32),
bits=torch.tensor([8], dtype=torch.float32), bits=8.0,
groupsize=torch.tensor([4], dtype=torch.float32), groupsize=4.0,
use_exllama=False, use_exllama=False,
) )
@ -495,12 +495,12 @@ def test_get_multi_weights_col_packed_gptq():
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
qzeros=torch.tensor([[1.0], [2.0]], dtype=torch.float32), qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
scales=torch.tensor([8], dtype=torch.int32), scales=torch.tensor([[8.0]], dtype=torch.float16),
g_idx=torch.tensor([1.0], dtype=torch.float32), g_idx=torch.tensor([1], dtype=torch.int32),
bits=torch.tensor([8], dtype=torch.float32), bits=8.0,
groupsize=torch.tensor([4], dtype=torch.float32), groupsize=4.0,
use_exllama=False, use_exllama=False,
) )
@ -532,18 +532,21 @@ def test_get_multi_weights_row_exl2():
quantize=quantize, quantize=quantize,
) )
scaled_scale_max = 0.3906 * 256
expected_weight = Exl2Weight( expected_weight = Exl2Weight(
q_weight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
q_scale=torch.tensor([8], dtype=torch.int32), q_scale=torch.tensor([8], dtype=torch.int32),
q_invperm=torch.tensor([1.0], dtype=torch.float32), q_invperm=torch.tensor([1], dtype=torch.int16),
q_scale_max=8, q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
q_groups=torch.tensor([4], dtype=torch.int32), q_groups=torch.tensor([4], dtype=torch.int16),
) )
assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
assert w.q_scale_max == expected_weight.q_scale_max assert torch.allclose(
w.q_scale_max, expected_weight.q_scale_max
), "q_scale_max mismatch"
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
@ -561,25 +564,14 @@ def test_get_multi_weights_col_exl2():
prefix = "weight" prefix = "weight"
quantize = "exl2" quantize = "exl2"
w = weights.get_multi_weights_col( try:
prefix=prefix, w = weights.get_multi_weights_col(
quantize=quantize, prefixes=[prefix],
dim=0, quantize=quantize,
) dim=0,
)
expected_weight = Exl2Weight( except ValueError as e:
q_weight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), assert e.args[0] == "get_multi_weights_col is not supported for exl2"
q_scale=torch.tensor([8], dtype=torch.int32),
q_invperm=torch.tensor([1.0], dtype=torch.float32),
q_scale_max=8,
q_groups=torch.tensor([4], dtype=torch.int32),
)
assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
assert w.q_scale_max == expected_weight.q_scale_max
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_multi_weights_row_awq(): def test_get_multi_weights_row_awq():
@ -602,12 +594,12 @@ def test_get_multi_weights_row_awq():
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
qzeros=torch.tensor([[1.0], [2.0]], dtype=torch.float32), qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
scales=torch.tensor([8], dtype=torch.int32), scales=torch.tensor([8.0], dtype=torch.float16),
g_idx=None, g_idx=None,
bits=torch.tensor([8], dtype=torch.float32), bits=8.0,
groupsize=torch.tensor([4], dtype=torch.float32), groupsize=4.0,
use_exllama=False, use_exllama=False,
) )
@ -654,7 +646,7 @@ def test_get_multi_weights_col_marlin():
"test_get_multi_weights_col_marlin", "test_get_multi_weights_col_marlin",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
) )
@ -663,14 +655,14 @@ def test_get_multi_weights_col_marlin():
quantize = "marlin" quantize = "marlin"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefix=prefix, prefixes=[prefix],
quantize=quantize, quantize=quantize,
dim=0, dim=0,
) )
expected_weight = MarlinWeight( expected_weight = MarlinWeight(
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
s=torch.tensor([0.5], dtype=torch.float16), s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
) )
assert torch.allclose(w.B, expected_weight.B), "B mismatch" assert torch.allclose(w.B, expected_weight.B), "B mismatch"
@ -683,7 +675,7 @@ def test_get_multi_weights_col_packed_marlin():
"test_get_multi_weights_col_packed_marlin", "test_get_multi_weights_col_packed_marlin",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
) )
@ -692,15 +684,17 @@ def test_get_multi_weights_col_packed_marlin():
quantize = "marlin" quantize = "marlin"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefix=prefix, prefixes=[prefix],
quantize=quantize, quantize=quantize,
dim=0, dim=0,
) )
expected_weight = MarlinWeight( expected_weight = MarlinWeight(
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
s=torch.tensor([0.5], dtype=torch.float16), s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
) )
print(expected_weight)
assert torch.allclose(w.B, expected_weight.B), "B mismatch" assert torch.allclose(w.B, expected_weight.B), "B mismatch"
assert torch.allclose(w.s, expected_weight.s), "s mismatch" assert torch.allclose(w.s, expected_weight.s), "s mismatch"