feat: improve weight tests
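
Rename the packed-column tests from test_get_multi_weights_col_packed*
to test_get_weights_col_packed*, add coverage for get_weights_col
(gptq, exl2) and for get_weights_col_packed with scalar and per-block
block_sizes, add awq and marlin variants for the packed and multi-col
paths, group the tests by the Weights method they exercise, and align
the exl2 fixture values (q_scale_max) with the expected outputs.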

Author: drbh
Date:   2024-06-20 20:38:19 +00:00
Parent: 313d29f1f9
Commit: 29e922d3d4

@@ -1,3 +1,4 @@
+import pytest
 import torch
 from text_generation_server.utils.weights import Weights
 from text_generation_server.layers.gptq import GPTQWeight
@@ -26,7 +27,7 @@ dummy_file_system = {
             dtype=torch.float32,
         ),
     },
-    "test_get_multi_weights_col_packed": {
+    "test_get_weights_col_packed": {
         "col_packed.weight": torch.tensor(
             [
                 [1, 2],
@@ -36,7 +37,18 @@ dummy_file_system = {
             ],
             dtype=torch.float32,
         ),
-        "col_packed_2.weight": torch.tensor(
+    },
+    "test_get_multi_weights_col": {
+        "col.weight": torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.float32,
+        ),
+        "col.weight": torch.tensor(
             [
                 [1, 2],
                 [3, 4],
@@ -57,6 +69,22 @@ dummy_file_system = {
             dtype=torch.float32,
         ),
     },
+    "test_get_weights_col_gptq": {
+        "weight.qweight": torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.float32,
+        ),
+        "weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
+        "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
+        "weight.scales": torch.tensor([[100.0], [100.0]], dtype=torch.float16),
+        "gptq_bits": torch.tensor([8], dtype=torch.float32),
+        "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
+    },
     "test_get_multi_weights_row_gptq": {
         "weight.qweight": torch.tensor(
             [
@@ -89,7 +117,7 @@ dummy_file_system = {
         "gptq_bits": torch.tensor([8], dtype=torch.float32),
         "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
     },
-    "test_get_multi_weights_col_packed_gptq": {
+    "test_get_weights_col_packed_gptq": {
         "col_packed.qweight": torch.tensor(
             [
                 [1, 2],
@@ -105,6 +133,22 @@ dummy_file_system = {
         "gptq_bits": torch.tensor([8], dtype=torch.float32),
         "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
     },
+    # TODO: review if col packed exl2 is supported
+    "test_get_weights_col_packed_exl2": {
+        "col_packed.q_weight": torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.int32,
+        ),
+        "col_packed.q_scale": torch.tensor([8], dtype=torch.int32),
+        "col_packed.q_invperm": torch.tensor([1.0], dtype=torch.int32),
+        "col_packed.q_scale_max": torch.tensor([100], dtype=torch.float16),
+        "col_packed.q_groups": torch.tensor([4], dtype=torch.int16),
+    },
     "test_get_multi_weights_row_exl2": {
         "weight.q_weight": torch.tensor(
             [
@@ -134,9 +178,24 @@ dummy_file_system = {
         "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
         "weight.q_scale": torch.tensor([8], dtype=torch.int32),
         "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
-        "weight.q_scale_max": torch.tensor([8], dtype=torch.float16),
+        "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
         "weight.q_groups": torch.tensor([4], dtype=torch.int16),
     },
+    "test_get_weights_col_exl2": {
+        "col_packed.q_weight": torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.int32,
+        ),
+        "col_packed.q_scale": torch.tensor([8], dtype=torch.int32),
+        "col_packed.q_invperm": torch.tensor([1.0], dtype=torch.int32),
+        "col_packed.q_scale_max": torch.tensor([100], dtype=torch.float16),
+        "col_packed.q_groups": torch.tensor([4], dtype=torch.int16),
+    },
     "test_get_multi_weights_row_marlin": {
         "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
         "weight.s": torch.tensor([0.5], dtype=torch.float16),
@@ -145,7 +204,7 @@ dummy_file_system = {
         "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
         "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
     },
-    "test_get_multi_weights_col_packed_marlin": {
+    "test_get_weights_col_packed_marlin": {
         "col_packed.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
         "col_packed.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
     },
@@ -295,7 +354,7 @@ def test_get_weights_col_packed():
     weights = MockWeights(
         [
-            "test_get_multi_weights_col_packed",
+            "test_get_weights_col_packed",
         ],
         device="cpu",
         dtype=torch.float32,
@@ -313,9 +372,41 @@ def test_get_weights_col_packed():
         block_sizes=block_sizes,
     )
+    assert torch.allclose(
+        w,
+        torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.float32,
+        ),
+    )
+
+
+def test_get_weights_col_packed_block_size():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_packed",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
     prefix = "col_packed"
     quantize = None
-    block_sizes = 1
+    block_sizes = 2
+
+    w = weights.get_weights_col_packed(
+        prefix=prefix,
+        quantize=quantize,
+        block_sizes=block_sizes,
+    )
     assert torch.allclose(
         w,
@@ -331,10 +422,11 @@ def test_get_weights_col_packed():
     )
 
 
-def test_get_multi_weights_col_packed():
+def test_get_weights_col_packed_block_size_arr():
     weights = MockWeights(
         [
-            "test_get_multi_weights_col_packed",
+            "test_get_weights_col_packed",
         ],
         device="cpu",
         dtype=torch.float32,
@@ -342,7 +434,42 @@ def test_get_multi_weights_col_packed():
         dummy_fs=dummy_file_system,
     )
-    prefixes = ["col_packed", "col_packed_2"]
+    prefix = "col_packed"
+    quantize = None
+    block_sizes = [1, 1]
+
+    w = weights.get_weights_col_packed(
+        prefix=prefix,
+        quantize=quantize,
+        block_sizes=block_sizes,
+    )
+    assert torch.allclose(
+        w,
+        torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.float32,
+        ),
+    )
+
+
+def test_get_multi_weights_col():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+    prefixes = ["col", "col"]
     quantize = None
     w = weights.get_multi_weights_col(
@@ -397,10 +524,10 @@ def test_get_multi_weights_row():
     )
 
 
-def test_get_multi_weights_row_gptq():
+def test_get_weights_col_gptq():
     weights = MockWeights(
         [
-            "test_get_multi_weights_row_gptq",
+            "test_get_weights_col_gptq",
         ],
         device="cpu",
         dtype=torch.float32,
@@ -411,15 +538,15 @@ def test_get_multi_weights_row_gptq():
     prefix = "weight"
     quantize = "gptq"
 
-    w = weights.get_multi_weights_row(
+    w = weights.get_weights_col(
         prefix=prefix,
         quantize=quantize,
     )
 
     expected_weight = GPTQWeight(
-        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
         qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
-        scales=torch.tensor([8.0], dtype=torch.float16),
+        scales=torch.tensor([[100.0], [100.0]], dtype=torch.float16),
         g_idx=torch.tensor([1], dtype=torch.int32),
         bits=8.0,
         groupsize=4.0,
@@ -435,6 +562,262 @@ def test_get_multi_weights_row_gptq():
     assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
 
 
+def test_get_weights_col_exl2():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_exl2",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+    prefix = "col_packed"
+    quantize = "exl2"
+
+    w = weights.get_weights_col(
+        prefix=prefix,
+        quantize=quantize,
+    )
+
+    scaled_scale_max = 0.3906 * 256
+    expected_weight = Exl2Weight(
+        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        q_scale=torch.tensor([8], dtype=torch.int32),
+        q_invperm=torch.tensor([1], dtype=torch.int16),
+        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
+        q_groups=torch.tensor([4], dtype=torch.int16),
+    )
+
+    assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
+    assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
+    assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
+    assert torch.allclose(
+        w.q_scale_max, expected_weight.q_scale_max
+    ), "q_scale_max mismatch"
+    assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
+
+
+# test_get_weights_col_packed
+
+
+def test_get_weights_col_packed_awq():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_packed_gptq",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+    prefix = "col_packed"
+    quantize = "awq"
+    block_sizes = 1
+
+    w = weights.get_weights_col_packed(
+        prefix=prefix,
+        quantize=quantize,
+        block_sizes=block_sizes,
+    )
+
+    expected_weight = GPTQWeight(
+        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
+        scales=torch.tensor([[8.0]], dtype=torch.float16),
+        g_idx=None,
+        bits=8.0,
+        groupsize=4.0,
+        use_exllama=False,
+    )
+
+    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
+    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
+    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
+    assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
+    assert w.bits == expected_weight.bits, "bits mismatch"
+    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
+
+
+@pytest.mark.skip(reason="Review expected functionality")
+def test_get_weights_col_packed_exl2():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_packed_exl2",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+    prefix = "col_packed"
+    quantize = "exl2"
+    block_sizes = 1
+
+    w = weights.get_weights_col_packed(
+        prefix=prefix,
+        quantize=quantize,
+        block_sizes=block_sizes,
+    )
+
+    scaled_scale_max = 0.3906 * 256
+    expected_weight = Exl2Weight(
+        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        q_scale=torch.tensor([8], dtype=torch.int32),
+        q_invperm=torch.tensor([1], dtype=torch.int16),
+        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
+        q_groups=torch.tensor([4], dtype=torch.int16),
+    )
+
+    assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
+    assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
+    assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
+    assert torch.allclose(
+        w.q_scale_max, expected_weight.q_scale_max
+    ), "q_scale_max mismatch"
+    assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
+
+
+def test_get_weights_col_packed_gptq():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_packed_gptq",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+    prefixes = ["col_packed"]
+    quantize = "gptq"
+
+    w = weights.get_multi_weights_col(
+        prefixes=prefixes,
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = GPTQWeight(
+        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
+        scales=torch.tensor([[8.0]], dtype=torch.float16),
+        g_idx=torch.tensor([1], dtype=torch.int32),
+        bits=8.0,
+        groupsize=4.0,
+        use_exllama=False,
+    )
+
+    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
+    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
+    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
+    assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
+    assert w.bits == expected_weight.bits, "bits mismatch"
+    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
+
+
+def test_get_weights_col_packed_marlin():
+    weights = MockWeights(
+        [
+            "test_get_weights_col_packed_marlin",
+        ],
+        device="cpu",
+        dtype=torch.float16,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+    prefix = "col_packed"
+    quantize = "marlin"
+
+    w = weights.get_multi_weights_col(
+        prefixes=[prefix],
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = MarlinWeight(
+        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
+    )
+
+    print(expected_weight)
+
+    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
+    assert torch.allclose(w.s, expected_weight.s), "s mismatch"
+
+
+# test_get_multi_weights_col
+
+
+def test_get_multi_weights_col_awq():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_gptq",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+    prefixes = ["weight"]
+    quantize = "awq"
+
+    w = weights.get_multi_weights_col(
+        prefixes=prefixes,
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = GPTQWeight(
+        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
+        scales=torch.tensor([[8.0]], dtype=torch.float16),
+        g_idx=None,
+        bits=8.0,
+        groupsize=4.0,
+        use_exllama=False,
+    )
+
+    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
+    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
+    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
+    assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
+    assert w.bits == expected_weight.bits, "bits mismatch"
+    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
+
+
+def test_get_multi_weights_col_exl2():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_exl2",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+    prefix = "weight"
+    quantize = "exl2"
+
+    try:
+        w = weights.get_multi_weights_col(
+            prefixes=[prefix],
+            quantize=quantize,
+            dim=0,
+        )
+    except ValueError as e:
+        assert e.args[0] == "get_multi_weights_col is not supported for exl2"
+
+
 def test_get_multi_weights_col_gptq():
     weights = MockWeights(
         [
@@ -474,10 +857,42 @@ def test_get_multi_weights_col_gptq():
     assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
 
 
-def test_get_multi_weights_col_packed_gptq():
+def test_get_multi_weights_col_marlin():
     weights = MockWeights(
         [
-            "test_get_multi_weights_col_packed_gptq",
+            "test_get_multi_weights_col_marlin",
+        ],
+        device="cpu",
+        dtype=torch.float16,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+    prefix = "weight"
+    quantize = "marlin"
+
+    w = weights.get_multi_weights_col(
+        prefixes=[prefix],
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = MarlinWeight(
+        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
+    )
+
+    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
+    assert torch.allclose(w.s, expected_weight.s), "s mismatch"
+
+
+# test_get_multi_weights_row
+
+
+def test_get_multi_weights_row_awq():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_row_gptq",
         ],
         device="cpu",
         dtype=torch.float32,
@@ -485,20 +900,19 @@ def test_get_multi_weights_col_packed_gptq():
         dummy_fs=dummy_file_system,
     )
-    prefixes = ["col_packed"]
-    quantize = "gptq"
+    prefix = "weight"
+    quantize = "awq"
 
-    w = weights.get_multi_weights_col(
-        prefixes=prefixes,
+    w = weights.get_multi_weights_row(
+        prefix=prefix,
         quantize=quantize,
-        dim=0,
     )
 
     expected_weight = GPTQWeight(
         qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
         qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
-        scales=torch.tensor([[8.0]], dtype=torch.float16),
-        g_idx=torch.tensor([1], dtype=torch.int32),
+        scales=torch.tensor([8.0], dtype=torch.float16),
+        g_idx=None,
         bits=8.0,
         groupsize=4.0,
         use_exllama=False,
@@ -507,7 +921,7 @@ def test_get_multi_weights_col_packed_gptq():
     assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
     assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
     assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
-    assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
+    assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
     assert w.bits == expected_weight.bits, "bits mismatch"
     assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
     assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@@ -550,31 +964,7 @@ def test_get_multi_weights_row_exl2():
     assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
 
 
-def test_get_multi_weights_col_exl2():
-    weights = MockWeights(
-        [
-            "test_get_multi_weights_col_exl2",
-        ],
-        device="cpu",
-        dtype=torch.float32,
-        process_group=dummy_process_group,
-        dummy_fs=dummy_file_system,
-    )
-    prefix = "weight"
-    quantize = "exl2"
-
-    try:
-        w = weights.get_multi_weights_col(
-            prefixes=[prefix],
-            quantize=quantize,
-            dim=0,
-        )
-    except ValueError as e:
-        assert e.args[0] == "get_multi_weights_col is not supported for exl2"
-
-
-def test_get_multi_weights_row_awq():
+def test_get_multi_weights_row_gptq():
     weights = MockWeights(
         [
             "test_get_multi_weights_row_gptq",
@@ -586,7 +976,7 @@ def test_get_multi_weights_row_awq():
     )
     prefix = "weight"
-    quantize = "awq"
+    quantize = "gptq"
 
     w = weights.get_multi_weights_row(
         prefix=prefix,
@@ -597,7 +987,7 @@ def test_get_multi_weights_row_awq():
         qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
         qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
         scales=torch.tensor([8.0], dtype=torch.float16),
-        g_idx=None,
+        g_idx=torch.tensor([1], dtype=torch.int32),
         bits=8.0,
         groupsize=4.0,
         use_exllama=False,
@@ -606,7 +996,7 @@ def test_get_multi_weights_row_awq():
     assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
     assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
     assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
-    assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
+    assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
     assert w.bits == expected_weight.bits, "bits mismatch"
     assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
     assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@@ -638,63 +1028,3 @@ def test_get_multi_weights_row_marlin():
     assert torch.allclose(w.B, expected_weight.B), "B mismatch"
     assert torch.allclose(w.s, expected_weight.s), "s mismatch"
-
-
-def test_get_multi_weights_col_marlin():
-    weights = MockWeights(
-        [
-            "test_get_multi_weights_col_marlin",
-        ],
-        device="cpu",
-        dtype=torch.float16,
-        process_group=dummy_process_group,
-        dummy_fs=dummy_file_system,
-    )
-    prefix = "weight"
-    quantize = "marlin"
-
-    w = weights.get_multi_weights_col(
-        prefixes=[prefix],
-        quantize=quantize,
-        dim=0,
-    )
-
-    expected_weight = MarlinWeight(
-        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
-        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
-    )
-
-    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
-    assert torch.allclose(w.s, expected_weight.s), "s mismatch"
-
-
-def test_get_multi_weights_col_packed_marlin():
-    weights = MockWeights(
-        [
-            "test_get_multi_weights_col_packed_marlin",
-        ],
-        device="cpu",
-        dtype=torch.float16,
-        process_group=dummy_process_group,
-        dummy_fs=dummy_file_system,
-    )
-    prefix = "col_packed"
-    quantize = "marlin"
-
-    w = weights.get_multi_weights_col(
-        prefixes=[prefix],
-        quantize=quantize,
-        dim=0,
-    )
-
-    expected_weight = MarlinWeight(
-        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
-        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
-    )
-
-    print(expected_weight)
-
-    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
-    assert torch.allclose(w.s, expected_weight.s), "s mismatch"
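
Note: every test in this diff follows the same harness pattern — a MockWeights
object is pointed at the in-memory dummy_file_system dict instead of real
safetensors files, and each test selects the pseudo-file named after itself.
The snippet below is a minimal, self-contained sketch of that pattern;
DummyWeights and its get_tensor hook are illustrative stand-ins, not the
commit's actual MockWeights implementation.

    import torch


    class DummyWeights:
        """Illustrative stand-in for this test file's MockWeights."""

        def __init__(self, filenames, device, dtype, dummy_fs):
            self.filenames = filenames  # which pseudo-files to search
            self.device = device
            self.dtype = dtype
            self.fs = dummy_fs  # {filename: {tensor_name: tensor}}

        def get_tensor(self, name):
            # Resolve a tensor by name from the in-memory "file system"
            # rather than reading a safetensors file from disk.
            for fname in self.filenames:
                if name in self.fs[fname]:
                    return self.fs[fname][name].to(self.device)
            raise KeyError(f"{name} not found in dummy fs")


    fs = {"demo": {"col_packed.weight": torch.arange(8.0).reshape(4, 2)}}
    w = DummyWeights(["demo"], device="cpu", dtype=torch.float32, dummy_fs=fs)
    print(w.get_tensor("col_packed.weight"))  # tensor([[0., 1.], ..., [6., 7.]])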