diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py
index dc6c73f3..aa87f9c3 100644
--- a/server/tests/utils/test_weights.py
+++ b/server/tests/utils/test_weights.py
@@ -28,7 +28,7 @@ dummy_file_system = {
         ),
     },
     "test_get_weights_col_packed": {
-        "col_packed.weight": torch.tensor(
+        "weight.weight": torch.tensor(
             [
                 [1, 2],
                 [3, 4],
@@ -39,7 +39,7 @@ dummy_file_system = {
         ),
     },
     "test_get_multi_weights_col": {
-        "col.weight": torch.tensor(
+        "weight.weight": torch.tensor(
             [
                 [1, 2],
                 [3, 4],
@@ -48,7 +48,7 @@ dummy_file_system = {
             ],
             dtype=torch.float32,
         ),
-        "col.weight": torch.tensor(
+        "weight.weight": torch.tensor(
             [
                 [1, 2],
                 [3, 4],
@@ -59,7 +59,7 @@ dummy_file_system = {
         ),
     },
     "test_get_multi_weights_row": {
-        "row_packed.weight": torch.tensor(
+        "weight.weight": torch.tensor(
             [
                 [1, 2],
                 [3, 4],
@@ -85,6 +85,10 @@ dummy_file_system = {
         "gptq_bits": torch.tensor([8], dtype=torch.float32),
         "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
     },
+    "test_get_weights_col_marlin": {
+        "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
+    },
     "test_get_multi_weights_row_gptq": {
         "weight.qweight": torch.tensor(
             [
@@ -118,7 +122,7 @@ dummy_file_system = {
         "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
     },
     "test_get_weights_col_packed_gptq": {
-        "col_packed.qweight": torch.tensor(
+        "weight.qweight": torch.tensor(
             [
                 [1, 2],
                 [3, 4],
@@ -127,15 +131,15 @@ dummy_file_system = {
             ],
             dtype=torch.int32,
         ),
-        "col_packed.g_idx": torch.tensor([1.0], dtype=torch.int32),
-        "col_packed.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
-        "col_packed.scales": torch.tensor([[8]], dtype=torch.float16),
+        "weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
+        "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
+        "weight.scales": torch.tensor([[8]], dtype=torch.float16),
         "gptq_bits": torch.tensor([8], dtype=torch.float32),
         "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
     },
-    # TODO: review id col packed exl2 is supported
+    # TODO: review if col packed exl2 is supported
     "test_get_weights_col_packed_exl2": {
-        "col_packed.q_weight": torch.tensor(
+        "weight.q_weight": torch.tensor(
             [
                 [1, 2],
                 [3, 4],
@@ -144,10 +148,10 @@ dummy_file_system = {
             ],
             dtype=torch.int32,
         ),
-        "col_packed.q_scale": torch.tensor([8], dtype=torch.int32),
-        "col_packed.q_invperm": torch.tensor([1.0], dtype=torch.int32),
-        "col_packed.q_scale_max": torch.tensor([100], dtype=torch.float16),
-        "col_packed.q_groups": torch.tensor([4], dtype=torch.int16),
+        "weight.q_scale": torch.tensor([8], dtype=torch.int32),
+        "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
+        "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
+        "weight.q_groups": torch.tensor([4], dtype=torch.int16),
     },
     "test_get_multi_weights_row_exl2": {
         "weight.q_weight": torch.tensor(
@@ -182,7 +186,7 @@ dummy_file_system = {
         "weight.q_groups": torch.tensor([4], dtype=torch.int16),
     },
     "test_get_weights_col_exl2": {
-        "col_packed.q_weight": torch.tensor(
+        "weight.q_weight": torch.tensor(
             [
                 [1, 2],
                 [3, 4],
@@ -191,10 +195,10 @@ dummy_file_system = {
             ],
             dtype=torch.int32,
         ),
-        "col_packed.q_scale": torch.tensor([8], dtype=torch.int32),
-        "col_packed.q_invperm": torch.tensor([1.0], dtype=torch.int32),
-        "col_packed.q_scale_max": torch.tensor([100], dtype=torch.float16),
-        "col_packed.q_groups": torch.tensor([4], dtype=torch.int16),
+        "weight.q_scale": torch.tensor([8], dtype=torch.int32),
+        "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
+        "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
+        "weight.q_groups": torch.tensor([4], dtype=torch.int16),
     },
     "test_get_multi_weights_row_marlin": {
         "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
@@ -205,8 +209,8 @@ dummy_file_system = {
         "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
     },
     "test_get_weights_col_packed_marlin": {
-        "col_packed.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
-        "col_packed.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
+        "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
     },
 }
 
@@ -362,7 +366,7 @@ def test_get_weights_col_packed():
         dummy_fs=dummy_file_system,
     )
 
-    prefix = "col_packed"
+    prefix = "weight"
     quantize = None
     block_sizes = 1
 
@@ -398,7 +402,7 @@ def test_get_weights_col_packed_block_size():
         dummy_fs=dummy_file_system,
     )
 
-    prefix = "col_packed"
+    prefix = "weight"
     quantize = None
     block_sizes = 2
 
@@ -434,7 +438,7 @@ def test_get_weights_col_packed_block_size_arr():
         dummy_fs=dummy_file_system,
     )
 
-    prefix = "col_packed"
+    prefix = "weight"
     quantize = None
     block_sizes = [1, 1]
 
@@ -469,7 +473,7 @@ def test_get_multi_weights_col():
         dummy_fs=dummy_file_system,
     )
 
-    prefixes = ["col", "col"]
+    prefixes = ["weight", "weight"]
     quantize = None
 
     w = weights.get_multi_weights_col(
@@ -507,7 +511,7 @@ def test_get_multi_weights_row():
         dummy_fs=dummy_file_system,
     )
 
-    prefix = "row_packed"
+    prefix = "weight"
     quantize = None
 
     w = weights.get_multi_weights_row(
@@ -524,6 +528,47 @@ def test_get_multi_weights_row():
     )
 
 
+# test_get_weights_col
+
+
+def test_get_weights_col_awq():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_gptq",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefix = "weight"
+    quantize = "awq"
+
+    w = weights.get_weights_col(
+        prefix=prefix,
+        quantize=quantize,
+    )
+
+    expected_weight = GPTQWeight(
+        qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
+        qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
+        scales=torch.tensor([[100.0], [100.0]], dtype=torch.float16),
+        g_idx=None,
+        bits=8.0,
+        groupsize=4.0,
+        use_exllama=False,
+    )
+
+    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
+    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
+    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
+    assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
+    assert w.bits == expected_weight.bits, "bits mismatch"
+    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
+
+
 def test_get_weights_col_gtpq():
     weights = MockWeights(
         [
@@ -573,7 +618,7 @@ def test_get_weights_col_exl2():
         dummy_fs=dummy_file_system,
     )
 
-    prefix = "col_packed"
+    prefix = "weight"
     quantize = "exl2"
 
     w = weights.get_weights_col(
@@ -599,6 +644,34 @@ def test_get_weights_col_exl2():
     assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
 
 
+def test_get_weights_col_marlin():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_marlin",
+        ],
+        device="cpu",
+        dtype=torch.float16,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefix = "weight"
+    quantize = "marlin"
+
+    w = weights.get_weights_col(
+        prefix=prefix,
+        quantize=quantize,
+    )
+
+    expected_weight = MarlinWeight(
+        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
+    )
+
+    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
+    assert torch.allclose(w.s, expected_weight.s), "s mismatch"
+
+
 # test_get_weights_col_packed
 
 
@@ -613,7 +686,7 @@ def test_get_weights_col_packed_awq():
         dummy_fs=dummy_file_system,
     )
 
-    prefix = "col_packed"
+    prefix = "weight"
     quantize = "awq"
     block_sizes = 1
 
@@ -654,7 +727,7 @@ def test_get_weights_col_packed_exl2():
         dummy_fs=dummy_file_system,
     )
 
-    prefix = "col_packed"
+    prefix = "weight"
     quantize = "exl2"
     block_sizes = 1
 
@@ -693,7 +766,7 @@ def test_get_weights_col_packed_gptq():
         dummy_fs=dummy_file_system,
    )
 
-    prefixes = ["col_packed"]
+    prefixes = ["weight"]
     quantize = "gptq"
 
     w = weights.get_multi_weights_col(
@@ -732,7 +805,7 @@ def test_get_weights_col_packed_marlin():
         dummy_fs=dummy_file_system,
     )
 
-    prefix = "col_packed"
+    prefix = "weight"
     quantize = "marlin"
 
     w = weights.get_multi_weights_col(