From e850cea85dbcab3bd31e4d4a4b188e7f890c9a35 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 21 Jun 2024 03:25:18 +0000 Subject: [PATCH] fix: tweak shapes --- server/tests/utils/test_weights.py | 161 +++++++++++++++++++---------- 1 file changed, 105 insertions(+), 56 deletions(-) diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index aa87f9c3..8f88b1f8 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -79,11 +79,23 @@ dummy_file_system = { ], dtype=torch.float32, ), - "weight.g_idx": torch.tensor([1.0], dtype=torch.int32), - "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32), - "weight.scales": torch.tensor([[100.0], [100.0]], dtype=torch.float16), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), "gptq_bits": torch.tensor([8], dtype=torch.float32), - "gptq_groupsize": torch.tensor([4], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), }, "test_get_weights_col_marlin": { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), @@ -99,11 +111,23 @@ dummy_file_system = { ], dtype=torch.int32, ), - "weight.g_idx": torch.tensor([1.0], dtype=torch.int32), - "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32), - "weight.scales": torch.tensor([8], dtype=torch.float16), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), "gptq_bits": torch.tensor([8], dtype=torch.float32), - "gptq_groupsize": torch.tensor([4], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), }, "test_get_multi_weights_col_gptq": { "weight.qweight": torch.tensor( @@ -115,11 +139,23 @@ dummy_file_system = { ], dtype=torch.int32, ), - "weight.g_idx": torch.tensor([1.0], dtype=torch.int32), - "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32), - "weight.scales": torch.tensor([[8]], dtype=torch.float16), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), "gptq_bits": torch.tensor([8], dtype=torch.float32), - "gptq_groupsize": torch.tensor([4], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), }, "test_get_weights_col_packed_gptq": { "weight.qweight": torch.tensor( @@ -131,13 +167,24 @@ dummy_file_system = { ], dtype=torch.int32, ), - "weight.g_idx": torch.tensor([1.0], dtype=torch.int32), - "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32), - "weight.scales": torch.tensor([[8]], dtype=torch.float16), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), "gptq_bits": torch.tensor([8], dtype=torch.float32), - "gptq_groupsize": torch.tensor([4], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), }, - # TODO: review if col packed exl2 is supported "test_get_weights_col_packed_exl2": { "weight.q_weight": torch.tensor( [ @@ -149,7 +196,7 @@ dummy_file_system = { dtype=torch.int32, ), "weight.q_scale": torch.tensor([8], dtype=torch.int32), - "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, @@ -164,7 +211,7 @@ dummy_file_system = { dtype=torch.int32, ), "weight.q_scale": torch.tensor([8], dtype=torch.int32), - "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, @@ -179,9 +226,7 @@ dummy_file_system = { dtype=torch.int32, ), "weight.q_scale": torch.tensor([8], dtype=torch.int32), - "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32), - "weight.q_scale": torch.tensor([8], dtype=torch.int32), - "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, @@ -196,13 +241,13 @@ dummy_file_system = { dtype=torch.int32, ), "weight.q_scale": torch.tensor([8], dtype=torch.int32), - "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, "test_get_multi_weights_row_marlin": { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - "weight.s": torch.tensor([0.5], dtype=torch.float16), + "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), }, "test_get_multi_weights_col_marlin": { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), @@ -552,11 +597,14 @@ def test_get_weights_col_awq(): expected_weight = GPTQWeight( qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), - qzeros=torch.tensor([[1], [2]], dtype=torch.int32), - scales=torch.tensor([[100.0], [100.0]], dtype=torch.float16), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor( + [[100.0, 100.0], [100.0, 100.0]], + dtype=torch.float16, + ), g_idx=None, bits=8.0, - groupsize=4.0, + groupsize=2.0, use_exllama=False, ) @@ -590,11 +638,11 @@ def test_get_weights_col_gtpq(): expected_weight = GPTQWeight( qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), - qzeros=torch.tensor([[1], [2]], dtype=torch.int32), - scales=torch.tensor([[100.0], [100.0]], dtype=torch.float16), - g_idx=torch.tensor([1], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, - groupsize=4.0, + groupsize=2.0, use_exllama=False, ) @@ -630,7 +678,7 @@ def test_get_weights_col_exl2(): expected_weight = Exl2Weight( q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), q_scale=torch.tensor([8], dtype=torch.int32), - q_invperm=torch.tensor([1], dtype=torch.int16), + q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16), q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), q_groups=torch.tensor([4], dtype=torch.int16), ) @@ -698,11 +746,11 @@ def test_get_weights_col_packed_awq(): expected_weight = GPTQWeight( qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), - qzeros=torch.tensor([[1], [2]], dtype=torch.int32), - scales=torch.tensor([[8.0]], dtype=torch.float16), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), g_idx=None, bits=8.0, - groupsize=4.0, + groupsize=2.0, use_exllama=False, ) @@ -777,11 +825,11 @@ def test_get_weights_col_packed_gptq(): expected_weight = GPTQWeight( qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), - qzeros=torch.tensor([[1], [2]], dtype=torch.int32), - scales=torch.tensor([[8.0]], dtype=torch.float16), - g_idx=torch.tensor([1], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, - groupsize=4.0, + groupsize=2.0, use_exllama=False, ) @@ -850,11 +898,11 @@ def test_get_multi_weights_col_awq(): expected_weight = GPTQWeight( qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), - qzeros=torch.tensor([[1], [2]], dtype=torch.int32), - scales=torch.tensor([[8.0]], dtype=torch.float16), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), g_idx=None, bits=8.0, - groupsize=4.0, + groupsize=2.0, use_exllama=False, ) @@ -913,11 +961,11 @@ def test_get_multi_weights_col_gptq(): expected_weight = GPTQWeight( qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), - qzeros=torch.tensor([[1], [2]], dtype=torch.int32), - scales=torch.tensor([[8.0]], dtype=torch.float16), - g_idx=torch.tensor([1], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, - groupsize=4.0, + groupsize=2.0, use_exllama=False, ) @@ -983,11 +1031,11 @@ def test_get_multi_weights_row_awq(): expected_weight = GPTQWeight( qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), - qzeros=torch.tensor([[1], [2]], dtype=torch.int32), - scales=torch.tensor([8.0], dtype=torch.float16), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), g_idx=None, bits=8.0, - groupsize=4.0, + groupsize=2.0, use_exllama=False, ) @@ -1018,12 +1066,13 @@ def test_get_multi_weights_row_exl2(): prefix=prefix, quantize=quantize, ) + print(w) scaled_scale_max = 0.3906 * 256 expected_weight = Exl2Weight( q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), q_scale=torch.tensor([8], dtype=torch.int32), - q_invperm=torch.tensor([1], dtype=torch.int16), + q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16), q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), q_groups=torch.tensor([4], dtype=torch.int16), ) @@ -1058,11 +1107,11 @@ def test_get_multi_weights_row_gptq(): expected_weight = GPTQWeight( qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), - qzeros=torch.tensor([[1], [2]], dtype=torch.int32), - scales=torch.tensor([8.0], dtype=torch.float16), - g_idx=torch.tensor([1], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, - groupsize=4.0, + groupsize=2.0, use_exllama=False, ) @@ -1081,7 +1130,7 @@ def test_get_multi_weights_row_marlin(): "test_get_multi_weights_row_marlin", ], device="cpu", - dtype=torch.float32, + dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, ) @@ -1096,7 +1145,7 @@ def test_get_multi_weights_row_marlin(): expected_weight = MarlinWeight( B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([0.5], dtype=torch.float16), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), ) assert torch.allclose(w.B, expected_weight.B), "B mismatch"