diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py
index cc92d199..dc6c73f3 100644
--- a/server/tests/utils/test_weights.py
+++ b/server/tests/utils/test_weights.py
@@ -1,3 +1,4 @@
+import pytest
 import torch
 from text_generation_server.utils.weights import Weights
 from text_generation_server.layers.gptq import GPTQWeight
@@ -26,7 +27,7 @@ dummy_file_system = {
             dtype=torch.float32,
         ),
     },
-    "test_get_multi_weights_col_packed": {
+    "test_get_weights_col_packed": {
         "col_packed.weight": torch.tensor(
             [
                 [1, 2],
@@ -36,7 +37,18 @@ dummy_file_system = {
             ],
             dtype=torch.float32,
         ),
-        "col_packed_2.weight": torch.tensor(
+    },
+    "test_get_multi_weights_col": {
+        "col.weight": torch.tensor(
             [
                 [1, 2],
                 [3, 4],
@@ -57,6 +69,22 @@ dummy_file_system = {
             dtype=torch.float32,
         ),
     },
+    "test_get_weights_col_gptq": {
+        "weight.qweight": torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.float32,
+        ),
+        "weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
+        "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
+        "weight.scales": torch.tensor([[100.0], [100.0]], dtype=torch.float16),
+        "gptq_bits": torch.tensor([8], dtype=torch.float32),
+        "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
+    },
     "test_get_multi_weights_row_gptq": {
         "weight.qweight": torch.tensor(
             [
@@ -89,7 +117,7 @@ dummy_file_system = {
         "gptq_bits": torch.tensor([8], dtype=torch.float32),
         "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
     },
-    "test_get_multi_weights_col_packed_gptq": {
+    "test_get_weights_col_packed_gptq": {
         "col_packed.qweight": torch.tensor(
             [
                 [1, 2],
@@ -105,6 +133,22 @@ dummy_file_system = {
         "gptq_bits": torch.tensor([8], dtype=torch.float32),
         "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
     },
+    # TODO: review if col packed exl2 is supported
+    "test_get_weights_col_packed_exl2": {
+        "col_packed.q_weight": torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.int32,
+        ),
+        "col_packed.q_scale": torch.tensor([8], dtype=torch.int32),
+        "col_packed.q_invperm": torch.tensor([1.0], dtype=torch.int32),
+        "col_packed.q_scale_max": torch.tensor([100], dtype=torch.float16),
+        "col_packed.q_groups": torch.tensor([4], dtype=torch.int16),
+    },
     "test_get_multi_weights_row_exl2": {
         "weight.q_weight": torch.tensor(
             [
@@ -134,9 +178,24 @@ dummy_file_system = {
         "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
         "weight.q_scale": torch.tensor([8], dtype=torch.int32),
         "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
-        "weight.q_scale_max": torch.tensor([8], dtype=torch.float16),
+        "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
         "weight.q_groups": torch.tensor([4], dtype=torch.int16),
     },
+    "test_get_weights_col_exl2": {
+        "col_packed.q_weight": torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.int32,
+        ),
+        "col_packed.q_scale": torch.tensor([8], dtype=torch.int32),
+        "col_packed.q_invperm": torch.tensor([1.0], dtype=torch.int32),
+        "col_packed.q_scale_max": torch.tensor([100], dtype=torch.float16),
+        "col_packed.q_groups": torch.tensor([4], dtype=torch.int16),
+    },
     "test_get_multi_weights_row_marlin": {
         "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
         "weight.s": torch.tensor([0.5], dtype=torch.float16),
     },
@@ -145,7 +204,7 @@ dummy_file_system = {
         "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
         "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
     },
-    "test_get_multi_weights_col_packed_marlin": {
+    "test_get_weights_col_packed_marlin": {
         "col_packed.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
         "col_packed.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
     },
@@ -295,7 +354,7 @@ def test_get_weights_col_packed():
 
     weights = MockWeights(
         [
-            "test_get_multi_weights_col_packed",
+            "test_get_weights_col_packed",
         ],
         device="cpu",
         dtype=torch.float32,
@@ -313,9 +372,41 @@ def test_get_weights_col_packed():
         block_sizes=block_sizes,
     )
 
+    assert torch.allclose(
+        w,
+        torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.float32,
+        ),
+    )
+
+
+def test_get_weights_col_packed_block_size():
+
+    weights = MockWeights(
+        [
+            "test_get_weights_col_packed",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
     prefix = "col_packed"
     quantize = None
-    block_sizes = 1
+    block_sizes = 2
+
+    w = weights.get_weights_col_packed(
+        prefix=prefix,
+        quantize=quantize,
+        block_sizes=block_sizes,
+    )
 
     assert torch.allclose(
         w,
@@ -331,10 +422,11 @@ def test_get_weights_col_packed():
     )
 
 
-def test_get_multi_weights_col_packed():
+def test_get_weights_col_packed_block_size_arr():
+
     weights = MockWeights(
         [
-            "test_get_multi_weights_col_packed",
+            "test_get_weights_col_packed",
         ],
         device="cpu",
         dtype=torch.float32,
@@ -342,7 +434,42 @@ def test_get_multi_weights_col_packed():
         dummy_fs=dummy_file_system,
     )
 
-    prefixes = ["col_packed", "col_packed_2"]
+    prefix = "col_packed"
+    quantize = None
+    block_sizes = [1, 1]
+
+    w = weights.get_weights_col_packed(
+        prefix=prefix,
+        quantize=quantize,
+        block_sizes=block_sizes,
+    )
+
+    assert torch.allclose(
+        w,
+        torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.float32,
+        ),
+    )
+
+
+def test_get_multi_weights_col():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefixes = ["col", "col"]
     quantize = None
 
     w = weights.get_multi_weights_col(
@@ -397,10 +524,10 @@ def test_get_multi_weights_row():
     )
 
 
-def test_get_multi_weights_row_gptq():
+def test_get_weights_col_gptq():
     weights = MockWeights(
         [
-            "test_get_multi_weights_row_gptq",
+            "test_get_weights_col_gptq",
         ],
         device="cpu",
         dtype=torch.float32,
@@ -411,15 +538,15 @@
     prefix = "weight"
     quantize = "gptq"
 
-    w = weights.get_multi_weights_row(
+    w = weights.get_weights_col(
         prefix=prefix,
         quantize=quantize,
     )
 
     expected_weight = GPTQWeight(
-        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
         qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
-        scales=torch.tensor([8.0], dtype=torch.float16),
+        scales=torch.tensor([[100.0], [100.0]], dtype=torch.float16),
         g_idx=torch.tensor([1], dtype=torch.int32),
         bits=8.0,
         groupsize=4.0,
@@ -435,6 +562,262 @@
     assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
 
 
+def test_get_weights_col_exl2():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_exl2",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefix = "col_packed"
+    quantize = "exl2"
+
+    w = weights.get_weights_col(
+        prefix=prefix,
+        quantize=quantize,
+    )
+
+    scaled_scale_max = 0.3906 * 256
+    expected_weight = Exl2Weight(
+        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        q_scale=torch.tensor([8], dtype=torch.int32),
+        q_invperm=torch.tensor([1], dtype=torch.int16),
+        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
+        q_groups=torch.tensor([4], dtype=torch.int16),
+    )
+
+    assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
+    assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
+    assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
+    assert torch.allclose(
+        w.q_scale_max, expected_weight.q_scale_max
+    ), "q_scale_max mismatch"
+    assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
+
+
+# test_get_weights_col_packed
+
+
+def test_get_weights_col_packed_awq():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_packed_gptq",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefix = "col_packed"
+    quantize = "awq"
+    block_sizes = 1
+
+    w = weights.get_weights_col_packed(
+        prefix=prefix,
+        quantize=quantize,
+        block_sizes=block_sizes,
+    )
+
+    expected_weight = GPTQWeight(
+        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
+        scales=torch.tensor([[8.0]], dtype=torch.float16),
+        g_idx=None,
+        bits=8.0,
+        groupsize=4.0,
+        use_exllama=False,
+    )
+
+    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
+    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
+    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
+    assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
+    assert w.bits == expected_weight.bits, "bits mismatch"
+    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
+
+
+@pytest.mark.skip(reason="Review expected functionality")
+def test_get_weights_col_packed_exl2():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_packed_exl2",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefix = "col_packed"
+    quantize = "exl2"
+    block_sizes = 1
+
+    w = weights.get_weights_col_packed(
+        prefix=prefix,
+        quantize=quantize,
+        block_sizes=block_sizes,
+    )
+
+    scaled_scale_max = 0.3906 * 256
+    expected_weight = Exl2Weight(
+        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        q_scale=torch.tensor([8], dtype=torch.int32),
+        q_invperm=torch.tensor([1], dtype=torch.int16),
+        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
+        q_groups=torch.tensor([4], dtype=torch.int16),
+    )
+
+    assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
+    assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
+    assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
+    assert torch.allclose(
+        w.q_scale_max, expected_weight.q_scale_max
+    ), "q_scale_max mismatch"
+    assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
+
+
+def test_get_weights_col_packed_gptq():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_packed_gptq",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefixes = ["col_packed"]
+    quantize = "gptq"
+
+    w = weights.get_multi_weights_col(
+        prefixes=prefixes,
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = GPTQWeight(
+        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
+        scales=torch.tensor([[8.0]], dtype=torch.float16),
+        g_idx=torch.tensor([1], dtype=torch.int32),
+        bits=8.0,
+        groupsize=4.0,
+        use_exllama=False,
+    )
+
+    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
+    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
+    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
+    assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
+    assert w.bits == expected_weight.bits, "bits mismatch"
+    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
+
+
+def test_get_weights_col_packed_marlin():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_packed_marlin",
+        ],
+        device="cpu",
+        dtype=torch.float16,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefix = "col_packed"
+    quantize = "marlin"
+
+    w = weights.get_multi_weights_col(
+        prefixes=[prefix],
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = MarlinWeight(
+        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
+    )
+
+    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
+    assert torch.allclose(w.s, expected_weight.s), "s mismatch"
+
+
+# test_get_multi_weights_col
+
+
+def test_get_multi_weights_col_awq():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_gptq",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefixes = ["weight"]
+    quantize = "awq"
+
+    w = weights.get_multi_weights_col(
+        prefixes=prefixes,
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = GPTQWeight(
+        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
+        scales=torch.tensor([[8.0]], dtype=torch.float16),
+        g_idx=None,
+        bits=8.0,
+        groupsize=4.0,
+        use_exllama=False,
+    )
+
+    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
+    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
+    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
+    assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
+    assert w.bits == expected_weight.bits, "bits mismatch"
+    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
+
+
+def test_get_multi_weights_col_exl2():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_exl2",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefix = "weight"
+    quantize = "exl2"
+
+    with pytest.raises(ValueError) as excinfo:
+        weights.get_multi_weights_col(
+            prefixes=[prefix],
+            quantize=quantize,
+            dim=0,
+        )
+    assert excinfo.value.args[0] == "get_multi_weights_col is not supported for exl2"
+
+
 def test_get_multi_weights_col_gptq():
     weights = MockWeights(
         [
@@ -474,10 +857,42 @@
     assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
 
 
-def test_get_multi_weights_col_packed_gptq():
+def test_get_multi_weights_col_marlin():
     weights = MockWeights(
         [
-            "test_get_multi_weights_col_packed_gptq",
"test_get_multi_weights_col_packed_gptq", + "test_get_multi_weights_col_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "marlin" + + w = weights.get_multi_weights_col( + prefixes=[prefix], + quantize=quantize, + dim=0, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" + + +# test_get_multi_weights_row + + +def test_get_multi_weights_row_awq(): + weights = MockWeights( + [ + "test_get_multi_weights_row_gptq", ], device="cpu", dtype=torch.float32, @@ -485,20 +900,19 @@ def test_get_multi_weights_col_packed_gptq(): dummy_fs=dummy_file_system, ) - prefixes = ["col_packed"] - quantize = "gptq" + prefix = "weight" + quantize = "awq" - w = weights.get_multi_weights_col( - prefixes=prefixes, + w = weights.get_multi_weights_row( + prefix=prefix, quantize=quantize, - dim=0, ) expected_weight = GPTQWeight( qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), qzeros=torch.tensor([[1], [2]], dtype=torch.int32), - scales=torch.tensor([[8.0]], dtype=torch.float16), - g_idx=torch.tensor([1], dtype=torch.int32), + scales=torch.tensor([8.0], dtype=torch.float16), + g_idx=None, bits=8.0, groupsize=4.0, use_exllama=False, @@ -507,7 +921,7 @@ def test_get_multi_weights_col_packed_gptq(): assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" - assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -550,31 +964,7 @@ def test_get_multi_weights_row_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_multi_weights_col_exl2(): - weights = MockWeights( - [ - "test_get_multi_weights_col_exl2", - ], - device="cpu", - dtype=torch.float32, - process_group=dummy_process_group, - dummy_fs=dummy_file_system, - ) - - prefix = "weight" - quantize = "exl2" - - try: - w = weights.get_multi_weights_col( - prefixes=[prefix], - quantize=quantize, - dim=0, - ) - except ValueError as e: - assert e.args[0] == "get_multi_weights_col is not supported for exl2" - - -def test_get_multi_weights_row_awq(): +def test_get_multi_weights_row_gptq(): weights = MockWeights( [ "test_get_multi_weights_row_gptq", @@ -586,7 +976,7 @@ def test_get_multi_weights_row_awq(): ) prefix = "weight" - quantize = "awq" + quantize = "gptq" w = weights.get_multi_weights_row( prefix=prefix, @@ -597,7 +987,7 @@ def test_get_multi_weights_row_awq(): qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), qzeros=torch.tensor([[1], [2]], dtype=torch.int32), scales=torch.tensor([8.0], dtype=torch.float16), - g_idx=None, + g_idx=torch.tensor([1], dtype=torch.int32), bits=8.0, groupsize=4.0, use_exllama=False, @@ -606,7 +996,7 @@ def test_get_multi_weights_row_awq(): assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" assert torch.allclose(w.qzeros, 
expected_weight.qzeros), "qzeros mismatch" assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" - assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -638,63 +1028,3 @@ def test_get_multi_weights_row_marlin(): assert torch.allclose(w.B, expected_weight.B), "B mismatch" assert torch.allclose(w.s, expected_weight.s), "s mismatch" - - -def test_get_multi_weights_col_marlin(): - weights = MockWeights( - [ - "test_get_multi_weights_col_marlin", - ], - device="cpu", - dtype=torch.float16, - process_group=dummy_process_group, - dummy_fs=dummy_file_system, - ) - - prefix = "weight" - quantize = "marlin" - - w = weights.get_multi_weights_col( - prefixes=[prefix], - quantize=quantize, - dim=0, - ) - - expected_weight = MarlinWeight( - B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - ) - - assert torch.allclose(w.B, expected_weight.B), "B mismatch" - assert torch.allclose(w.s, expected_weight.s), "s mismatch" - - -def test_get_multi_weights_col_packed_marlin(): - weights = MockWeights( - [ - "test_get_multi_weights_col_packed_marlin", - ], - device="cpu", - dtype=torch.float16, - process_group=dummy_process_group, - dummy_fs=dummy_file_system, - ) - - prefix = "col_packed" - quantize = "marlin" - - w = weights.get_multi_weights_col( - prefixes=[prefix], - quantize=quantize, - dim=0, - ) - - expected_weight = MarlinWeight( - B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - ) - - print(expected_weight) - - assert torch.allclose(w.B, expected_weight.B), "B mismatch" - assert torch.allclose(w.s, expected_weight.s), "s mismatch"