mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: tweak shapes
This commit is contained in:
parent
b16109966d
commit
e850cea85d
@ -79,11 +79,23 @@ dummy_file_system = {
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
"weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
|
||||
"weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
|
||||
"weight.scales": torch.tensor([[100.0], [100.0]], dtype=torch.float16),
|
||||
"weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
"weight.qzeros": torch.tensor(
|
||||
[
|
||||
[0, 1],
|
||||
[1, 0],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight.scales": torch.tensor(
|
||||
[
|
||||
[100.0, 100.0],
|
||||
[100.0, 100.0],
|
||||
],
|
||||
dtype=torch.float16,
|
||||
),
|
||||
"gptq_bits": torch.tensor([8], dtype=torch.float32),
|
||||
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
|
||||
"gptq_groupsize": torch.tensor([2], dtype=torch.float32),
|
||||
},
|
||||
"test_get_weights_col_marlin": {
|
||||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
@ -99,11 +111,23 @@ dummy_file_system = {
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
|
||||
"weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
|
||||
"weight.scales": torch.tensor([8], dtype=torch.float16),
|
||||
"weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
"weight.qzeros": torch.tensor(
|
||||
[
|
||||
[0, 1],
|
||||
[1, 0],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight.scales": torch.tensor(
|
||||
[
|
||||
[100.0, 100.0],
|
||||
[100.0, 100.0],
|
||||
],
|
||||
dtype=torch.float16,
|
||||
),
|
||||
"gptq_bits": torch.tensor([8], dtype=torch.float32),
|
||||
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
|
||||
"gptq_groupsize": torch.tensor([2], dtype=torch.float32),
|
||||
},
|
||||
"test_get_multi_weights_col_gptq": {
|
||||
"weight.qweight": torch.tensor(
|
||||
@ -115,11 +139,23 @@ dummy_file_system = {
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
|
||||
"weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
|
||||
"weight.scales": torch.tensor([[8]], dtype=torch.float16),
|
||||
"weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
"weight.qzeros": torch.tensor(
|
||||
[
|
||||
[0, 1],
|
||||
[1, 0],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight.scales": torch.tensor(
|
||||
[
|
||||
[100.0, 100.0],
|
||||
[100.0, 100.0],
|
||||
],
|
||||
dtype=torch.float16,
|
||||
),
|
||||
"gptq_bits": torch.tensor([8], dtype=torch.float32),
|
||||
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
|
||||
"gptq_groupsize": torch.tensor([2], dtype=torch.float32),
|
||||
},
|
||||
"test_get_weights_col_packed_gptq": {
|
||||
"weight.qweight": torch.tensor(
|
||||
@ -131,13 +167,24 @@ dummy_file_system = {
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight.g_idx": torch.tensor([1.0], dtype=torch.int32),
|
||||
"weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32),
|
||||
"weight.scales": torch.tensor([[8]], dtype=torch.float16),
|
||||
"weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
"weight.qzeros": torch.tensor(
|
||||
[
|
||||
[0, 1],
|
||||
[1, 0],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight.scales": torch.tensor(
|
||||
[
|
||||
[100.0, 100.0],
|
||||
[100.0, 100.0],
|
||||
],
|
||||
dtype=torch.float16,
|
||||
),
|
||||
"gptq_bits": torch.tensor([8], dtype=torch.float32),
|
||||
"gptq_groupsize": torch.tensor([4], dtype=torch.float32),
|
||||
"gptq_groupsize": torch.tensor([2], dtype=torch.float32),
|
||||
},
|
||||
# TODO: review if col packed exl2 is supported
|
||||
"test_get_weights_col_packed_exl2": {
|
||||
"weight.q_weight": torch.tensor(
|
||||
[
|
||||
@ -149,7 +196,7 @@ dummy_file_system = {
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
|
||||
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
|
||||
"weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
|
||||
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||
},
|
||||
@ -164,7 +211,7 @@ dummy_file_system = {
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
|
||||
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
|
||||
"weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
|
||||
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||
},
|
||||
@ -179,9 +226,7 @@ dummy_file_system = {
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
|
||||
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
|
||||
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
|
||||
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
|
||||
"weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
|
||||
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||
},
|
||||
@ -196,13 +241,13 @@ dummy_file_system = {
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
|
||||
"weight.q_invperm": torch.tensor([1.0], dtype=torch.int32),
|
||||
"weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
|
||||
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||
},
|
||||
"test_get_multi_weights_row_marlin": {
|
||||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
"weight.s": torch.tensor([0.5], dtype=torch.float16),
|
||||
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
|
||||
},
|
||||
"test_get_multi_weights_col_marlin": {
|
||||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
@ -552,11 +597,14 @@ def test_get_weights_col_awq():
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
|
||||
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||
scales=torch.tensor([[100.0], [100.0]], dtype=torch.float16),
|
||||
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
|
||||
scales=torch.tensor(
|
||||
[[100.0, 100.0], [100.0, 100.0]],
|
||||
dtype=torch.float16,
|
||||
),
|
||||
g_idx=None,
|
||||
bits=8.0,
|
||||
groupsize=4.0,
|
||||
groupsize=2.0,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
@ -590,11 +638,11 @@ def test_get_weights_col_gtpq():
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
|
||||
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||
scales=torch.tensor([[100.0], [100.0]], dtype=torch.float16),
|
||||
g_idx=torch.tensor([1], dtype=torch.int32),
|
||||
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
|
||||
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
|
||||
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
bits=8.0,
|
||||
groupsize=4.0,
|
||||
groupsize=2.0,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
@ -630,7 +678,7 @@ def test_get_weights_col_exl2():
|
||||
expected_weight = Exl2Weight(
|
||||
q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||
q_scale=torch.tensor([8], dtype=torch.int32),
|
||||
q_invperm=torch.tensor([1], dtype=torch.int16),
|
||||
q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16),
|
||||
q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
|
||||
q_groups=torch.tensor([4], dtype=torch.int16),
|
||||
)
|
||||
@ -698,11 +746,11 @@ def test_get_weights_col_packed_awq():
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||
scales=torch.tensor([[8.0]], dtype=torch.float16),
|
||||
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
|
||||
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
|
||||
g_idx=None,
|
||||
bits=8.0,
|
||||
groupsize=4.0,
|
||||
groupsize=2.0,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
@ -777,11 +825,11 @@ def test_get_weights_col_packed_gptq():
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||
scales=torch.tensor([[8.0]], dtype=torch.float16),
|
||||
g_idx=torch.tensor([1], dtype=torch.int32),
|
||||
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
|
||||
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
|
||||
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
bits=8.0,
|
||||
groupsize=4.0,
|
||||
groupsize=2.0,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
@ -850,11 +898,11 @@ def test_get_multi_weights_col_awq():
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||
scales=torch.tensor([[8.0]], dtype=torch.float16),
|
||||
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
|
||||
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
|
||||
g_idx=None,
|
||||
bits=8.0,
|
||||
groupsize=4.0,
|
||||
groupsize=2.0,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
@ -913,11 +961,11 @@ def test_get_multi_weights_col_gptq():
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||
scales=torch.tensor([[8.0]], dtype=torch.float16),
|
||||
g_idx=torch.tensor([1], dtype=torch.int32),
|
||||
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
|
||||
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
|
||||
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
bits=8.0,
|
||||
groupsize=4.0,
|
||||
groupsize=2.0,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
@ -983,11 +1031,11 @@ def test_get_multi_weights_row_awq():
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||
scales=torch.tensor([8.0], dtype=torch.float16),
|
||||
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
|
||||
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
|
||||
g_idx=None,
|
||||
bits=8.0,
|
||||
groupsize=4.0,
|
||||
groupsize=2.0,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
@ -1018,12 +1066,13 @@ def test_get_multi_weights_row_exl2():
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
print(w)
|
||||
|
||||
scaled_scale_max = 0.3906 * 256
|
||||
expected_weight = Exl2Weight(
|
||||
q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||
q_scale=torch.tensor([8], dtype=torch.int32),
|
||||
q_invperm=torch.tensor([1], dtype=torch.int16),
|
||||
q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16),
|
||||
q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
|
||||
q_groups=torch.tensor([4], dtype=torch.int16),
|
||||
)
|
||||
@ -1058,11 +1107,11 @@ def test_get_multi_weights_row_gptq():
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
|
||||
qzeros=torch.tensor([[1], [2]], dtype=torch.int32),
|
||||
scales=torch.tensor([8.0], dtype=torch.float16),
|
||||
g_idx=torch.tensor([1], dtype=torch.int32),
|
||||
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
|
||||
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
|
||||
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
bits=8.0,
|
||||
groupsize=4.0,
|
||||
groupsize=2.0,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
@ -1081,7 +1130,7 @@ def test_get_multi_weights_row_marlin():
|
||||
"test_get_multi_weights_row_marlin",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
)
|
||||
@ -1096,7 +1145,7 @@ def test_get_multi_weights_row_marlin():
|
||||
|
||||
expected_weight = MarlinWeight(
|
||||
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
s=torch.tensor([0.5], dtype=torch.float16),
|
||||
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
||||
)
|
||||
|
||||
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
|
||||
|
Loading…
Reference in New Issue
Block a user