fix: adjust types and add tests

drbh 2024-06-20 14:56:54 -04:00
parent 70e1982ab2
commit 7ee217475e


@@ -65,12 +65,43 @@ dummy_file_system = {
                 [5, 6],
                 [7, 8],
             ],
-            dtype=torch.float32,
+            dtype=torch.int32,
         ),
-        "weight.g_idx": torch.tensor([1.0], dtype=torch.float32),
-        "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.float32),
-        "weight.scales": torch.tensor([8], dtype=torch.int32),
+        "weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
+        "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
+        "weight.scales": torch.tensor([8], dtype=torch.float16),
+        # "gptq_bits": torch.tensor([8], dtype=torch.float32),
+        "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
+    },
"test_get_multi_weights_col_gptq": {
"weight.qweight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.int32,
),
"weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
"weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
"weight.scales": torch.tensor([8], dtype=torch.float16),
"gptq_bits": torch.tensor([8], dtype=torch.float32),
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
},
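+    # NOTE (assumption): the "col_packed." prefix below mimics a fused,
+    # QKV-style checkpoint entry, so get_multi_weights_col can resolve keys
+    # such as "col_packed.weight.qweight" from the mocked file system.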
"test_get_multi_weights_col_packed_gptq": {
"col_packed.weight.qweight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.int32,
),
"col_packed.weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
"col_packed.weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
"col_packed.weight.scales": torch.tensor([8], dtype=torch.float16),
"gptq_bits": torch.tensor([8], dtype=torch.float32), "gptq_bits": torch.tensor([8], dtype=torch.float32),
"gptq_groupsize": torch.tensor([4], dtype=torch.float32), "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
}, },
@@ -82,18 +113,42 @@ dummy_file_system = {
                 [5, 6],
                 [7, 8],
             ],
-            dtype=torch.float32,
+            dtype=torch.int32,
         ),
         "weight.q_scale": torch.tensor([8], dtype=torch.int32),
-        "weight.q_invperm": torch.tensor([1.0], dtype=torch.float32),
-        "weight.q_scale_max": 8,
-        "weight.q_groups": torch.tensor([4], dtype=torch.int32),
+        "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
+        "weight.q_scale_max": torch.tensor([8], dtype=torch.float16),
+        "weight.q_groups": torch.tensor([4], dtype=torch.int16),
+    },
+    "test_get_multi_weights_col_exl2": {
+        "weight.q_weight": torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.int32,
+        ),
+        "weight.q_scale": torch.tensor([8], dtype=torch.int32),
+        "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
+        "weight.q_scale_max": torch.tensor([8], dtype=torch.float16),
+        "weight.q_groups": torch.tensor([4], dtype=torch.int16),
     },
     "test_get_multi_weights_row_marlin": {
-        "weight.scales": torch.tensor([8], dtype=torch.float16),
         "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
         "weight.s": torch.tensor([0.5], dtype=torch.float16),
     },
"test_get_multi_weights_col_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([0.5], dtype=torch.float16),
},
"test_get_multi_weights_col_packed_marlin": {
"col_packed.weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"col_packed.weight.s": torch.tensor([0.5], dtype=torch.float16),
},
 }
@@ -380,6 +435,84 @@ def test_get_multi_weights_row_gptq():
     assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
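+
+
+# NOTE (assumption): MockWeights.get_multi_weights_col mirrors the production
+# loader's signature (a list of prefixes concatenated along `dim`) and returns
+# tensors with the same dtypes stored in dummy_file_system above.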
+def test_get_multi_weights_col_gptq():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_gptq",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefixes = ["weight"]
+    quantize = "gptq"
+
+    w = weights.get_multi_weights_col(
+        prefixes=prefixes,
+        quantize=quantize,
+        dim=0,
+    )
+
+    # Expected tensors match the dummy_file_system dtypes; torch.allclose
+    # requires matching dtypes, so float defaults would error here.
+    expected_weight = GPTQWeight(
+        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
+        scales=torch.tensor([8], dtype=torch.float16),
+        g_idx=torch.tensor([1], dtype=torch.int32),
+        bits=8.0,
+        groupsize=4.0,
+        use_exllama=False,
+    )
+
+    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
+    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
+    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
+    assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
+    assert w.bits == expected_weight.bits, "bits mismatch"
+    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
+
+
+def test_get_multi_weights_col_packed_gptq():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_packed_gptq",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefixes = ["col_packed"]
+    quantize = "gptq"
+
+    w = weights.get_multi_weights_col(
+        prefixes=prefixes,
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = GPTQWeight(
+        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
+        scales=torch.tensor([8], dtype=torch.float16),
+        g_idx=torch.tensor([1], dtype=torch.int32),
+        bits=8.0,
+        groupsize=4.0,
+        use_exllama=False,
+    )
+
+    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
+    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
+    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
+    assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
+    assert w.bits == expected_weight.bits, "bits mismatch"
+    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
+
+
 def test_get_multi_weights_row_exl2():
     weights = MockWeights(
         [
@@ -414,6 +547,41 @@ def test_get_multi_weights_row_exl2():
     assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
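+
+
+# NOTE (assumption): the fixture now stores q_scale_max as a float16 tensor
+# rather than a bare int, so the col test compares it with torch.allclose.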
+def test_get_multi_weights_col_exl2():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_exl2",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefixes = ["weight"]
+    quantize = "exl2"
+
+    w = weights.get_multi_weights_col(
+        prefixes=prefixes,
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = Exl2Weight(
+        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        q_scale=torch.tensor([8], dtype=torch.int32),
+        q_invperm=torch.tensor([1], dtype=torch.int32),
+        q_scale_max=torch.tensor([8], dtype=torch.float16),
+        q_groups=torch.tensor([4], dtype=torch.int16),
+    )
+
+    assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
+    assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
+    assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
+    assert torch.allclose(
+        w.q_scale_max, expected_weight.q_scale_max
+    ), "q_scale_max mismatch"
+    assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
+
+
 def test_get_multi_weights_row_awq():
     weights = MockWeights(
         [
@@ -478,3 +646,61 @@ def test_get_multi_weights_row_marlin():
     assert torch.allclose(w.B, expected_weight.B), "B mismatch"
     assert torch.allclose(w.s, expected_weight.s), "s mismatch"
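+
+
+# The Marlin tests compare the packed int32 weight matrix `B` and its float16
+# scales `s`; the col variants below reuse the row-test expectations unchanged.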
+def test_get_multi_weights_col_marlin():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_marlin",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefixes = ["weight"]
+    quantize = "marlin"
+
+    w = weights.get_multi_weights_col(
+        prefixes=prefixes,
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = MarlinWeight(
+        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        s=torch.tensor([0.5], dtype=torch.float16),
+    )
+
+    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
+    assert torch.allclose(w.s, expected_weight.s), "s mismatch"
+
+
+def test_get_multi_weights_col_packed_marlin():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_packed_marlin",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefixes = ["col_packed"]
+    quantize = "marlin"
+
+    w = weights.get_multi_weights_col(
+        prefixes=prefixes,
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = MarlinWeight(
+        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        s=torch.tensor([0.5], dtype=torch.float16),
+    )
+
+    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
+    assert torch.allclose(w.s, expected_weight.s), "s mismatch"