mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: improve weight loading and add tests
This commit is contained in:
parent
8c3530f705
commit
5ec88a1b51
187
server/tests/utils/test_adapter.py
Normal file
187
server/tests/utils/test_adapter.py
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock
|
||||||
|
from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_attn_weights():
|
||||||
|
# create a mock layer
|
||||||
|
mock_layer = Mock()
|
||||||
|
mock_layer.self_attn.query_key_value = Mock()
|
||||||
|
mock_layer.self_attn.o_proj = Mock()
|
||||||
|
|
||||||
|
# call the function
|
||||||
|
result = get_attn_weights(2, mock_layer)
|
||||||
|
|
||||||
|
# assert the result
|
||||||
|
expected = {
|
||||||
|
(2, "q_proj"): (
|
||||||
|
"model.layers.2.self_attn.q_proj",
|
||||||
|
mock_layer.self_attn.query_key_value,
|
||||||
|
),
|
||||||
|
(2, "k_proj"): (
|
||||||
|
"model.layers.2.self_attn.k_proj",
|
||||||
|
mock_layer.self_attn.query_key_value,
|
||||||
|
),
|
||||||
|
(2, "v_proj"): (
|
||||||
|
"model.layers.2.self_attn.v_proj",
|
||||||
|
mock_layer.self_attn.query_key_value,
|
||||||
|
),
|
||||||
|
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
|
||||||
|
}
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_mlp_weights_with_gate_up_proj():
|
||||||
|
# create a mock layer with gate_up_proj
|
||||||
|
mock_layer = Mock()
|
||||||
|
mock_layer.mlp.gate_up_proj = Mock()
|
||||||
|
mock_layer.mlp.down_proj = Mock()
|
||||||
|
|
||||||
|
# call the function
|
||||||
|
result = get_mlp_weights(3, mock_layer)
|
||||||
|
|
||||||
|
# assert the result
|
||||||
|
expected = {
|
||||||
|
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
|
||||||
|
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
|
||||||
|
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
||||||
|
}
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_mlp_weights_without_gate_up_proj():
|
||||||
|
# create a mock layer without gate_up_proj
|
||||||
|
mock_layer = Mock()
|
||||||
|
mock_layer.mlp = Mock(spec=[])
|
||||||
|
|
||||||
|
# call the function
|
||||||
|
result = get_mlp_weights(1, mock_layer)
|
||||||
|
|
||||||
|
# assert the result
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("layer_index", [0, 1, 5])
|
||||||
|
def test_get_attn_weights_different_layers(layer_index):
|
||||||
|
mock_layer = Mock()
|
||||||
|
mock_layer.self_attn.query_key_value = Mock()
|
||||||
|
mock_layer.self_attn.o_proj = Mock()
|
||||||
|
|
||||||
|
result = get_attn_weights(layer_index, mock_layer)
|
||||||
|
|
||||||
|
for k in ["q", "k", "v"]:
|
||||||
|
assert (layer_index, f"{k}_proj") in result
|
||||||
|
assert (
|
||||||
|
result[(layer_index, f"{k}_proj")][0]
|
||||||
|
== f"model.layers.{layer_index}.self_attn.{k}_proj"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (layer_index, "o_proj") in result
|
||||||
|
assert (
|
||||||
|
result[(layer_index, "o_proj")][0]
|
||||||
|
== f"model.layers.{layer_index}.self_attn.o_proj"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("layer_index", [0, 1, 5])
|
||||||
|
def test_get_mlp_weights_different_layers(layer_index):
|
||||||
|
mock_layer = Mock()
|
||||||
|
mock_layer.mlp.gate_up_proj = Mock()
|
||||||
|
mock_layer.mlp.down_proj = Mock()
|
||||||
|
|
||||||
|
result = get_mlp_weights(layer_index, mock_layer)
|
||||||
|
|
||||||
|
for k in ["gate", "up", "down"]:
|
||||||
|
assert (layer_index, f"{k}_proj") in result
|
||||||
|
assert (
|
||||||
|
result[(layer_index, f"{k}_proj")][0]
|
||||||
|
== f"model.layers.{layer_index}.mlp.{k}_proj"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_attn_weights_llama_compatibility():
|
||||||
|
mock_layer = Mock()
|
||||||
|
mock_layer.self_attn.query_key_value = Mock()
|
||||||
|
mock_layer.self_attn.o_proj = Mock()
|
||||||
|
|
||||||
|
result = get_attn_weights(2, mock_layer)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
(2, "q_proj"): (
|
||||||
|
"model.layers.2.self_attn.q_proj",
|
||||||
|
mock_layer.self_attn.query_key_value,
|
||||||
|
),
|
||||||
|
(2, "k_proj"): (
|
||||||
|
"model.layers.2.self_attn.k_proj",
|
||||||
|
mock_layer.self_attn.query_key_value,
|
||||||
|
),
|
||||||
|
(2, "v_proj"): (
|
||||||
|
"model.layers.2.self_attn.v_proj",
|
||||||
|
mock_layer.self_attn.query_key_value,
|
||||||
|
),
|
||||||
|
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
|
||||||
|
}
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_mlp_weights_llama_compatibility():
|
||||||
|
mock_layer = Mock()
|
||||||
|
mock_layer.mlp.gate_up_proj = Mock()
|
||||||
|
mock_layer.mlp.down_proj = Mock()
|
||||||
|
|
||||||
|
result = get_mlp_weights(3, mock_layer)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
|
||||||
|
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
|
||||||
|
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
||||||
|
}
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_attn_weights_gemma_compatibility():
|
||||||
|
mock_layer = Mock()
|
||||||
|
mock_layer.self_attn.query_key_value = Mock()
|
||||||
|
mock_layer.self_attn.o_proj = Mock()
|
||||||
|
|
||||||
|
result = get_attn_weights(2, mock_layer)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
(2, "q_proj"): (
|
||||||
|
"model.layers.2.self_attn.q_proj",
|
||||||
|
mock_layer.self_attn.query_key_value,
|
||||||
|
),
|
||||||
|
(2, "k_proj"): (
|
||||||
|
"model.layers.2.self_attn.k_proj",
|
||||||
|
mock_layer.self_attn.query_key_value,
|
||||||
|
),
|
||||||
|
(2, "v_proj"): (
|
||||||
|
"model.layers.2.self_attn.v_proj",
|
||||||
|
mock_layer.self_attn.query_key_value,
|
||||||
|
),
|
||||||
|
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
|
||||||
|
}
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_mlp_weights_gemma_compatibility():
|
||||||
|
mock_layer = Mock()
|
||||||
|
mock_layer.mlp.gate_proj = Mock()
|
||||||
|
mock_layer.mlp.up_proj = Mock()
|
||||||
|
mock_layer.mlp.down_proj = Mock()
|
||||||
|
|
||||||
|
# ensure that the mock_layer.mlp.gate_up_proj attribute does not exist.
|
||||||
|
# This is necessary because the use of `Mock` automatically creates any
|
||||||
|
# attributes that are accessed, even if they don't exist in the actual
|
||||||
|
# implementation. If `gate_up_proj` were created, `get_mlp_weights` might
|
||||||
|
# follow the wrong execution path and return an incorrect result.
|
||||||
|
del mock_layer.mlp.gate_up_proj
|
||||||
|
|
||||||
|
result = get_mlp_weights(3, mock_layer)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
|
||||||
|
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
|
||||||
|
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
||||||
|
}
|
||||||
|
assert result == expected
|
@ -231,17 +231,34 @@ def get_attn_weights(i, layer):
|
|||||||
|
|
||||||
def get_mlp_weights(i, layer):
|
def get_mlp_weights(i, layer):
|
||||||
weights = {}
|
weights = {}
|
||||||
|
if hasattr(layer, "mlp"):
|
||||||
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
|
mlp = layer.mlp
|
||||||
gup = layer.mlp.gate_up_proj
|
if hasattr(mlp, "gate_up_proj"):
|
||||||
|
# handle combined gate_up_proj (e.g., for some LLaMA variants)
|
||||||
for k in ["gate", "up", "down"]:
|
weights.update(
|
||||||
key = (i, f"{k}_proj")
|
{
|
||||||
value = (
|
(i, "gate_proj"): (
|
||||||
f"model.layers.{i}.mlp.{k}_proj",
|
f"model.layers.{i}.mlp.gate_proj",
|
||||||
gup if k != "down" else layer.mlp.down_proj,
|
mlp.gate_up_proj,
|
||||||
|
),
|
||||||
|
(i, "up_proj"): (f"model.layers.{i}.mlp.up_proj", mlp.gate_up_proj),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma)
|
||||||
|
if hasattr(mlp, "gate_proj"):
|
||||||
|
weights[(i, "gate_proj")] = (
|
||||||
|
f"model.layers.{i}.mlp.gate_proj",
|
||||||
|
mlp.gate_proj,
|
||||||
|
)
|
||||||
|
if hasattr(mlp, "up_proj"):
|
||||||
|
weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
|
||||||
|
|
||||||
|
if hasattr(mlp, "down_proj"):
|
||||||
|
weights[(i, "down_proj")] = (
|
||||||
|
f"model.layers.{i}.mlp.down_proj",
|
||||||
|
mlp.down_proj,
|
||||||
)
|
)
|
||||||
weights[key] = value
|
|
||||||
|
|
||||||
return weights
|
return weights
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user