fix: add qkv_proj weights to weight test

This commit is contained in:
drbh 2024-09-02 13:57:37 +00:00
parent 8666df68d6
commit 85b5ce6539

View File

@ -70,6 +70,10 @@ def test_get_attn_weights():
"model.layers.2.self_attn.k_proj", "model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value, mock_layer.self_attn.query_key_value,
), ),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): ( (2, "v_proj"): (
"model.layers.2.self_attn.v_proj", "model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value, mock_layer.self_attn.query_key_value,
@ -163,6 +167,10 @@ def test_get_attn_weights_llama_compatibility():
"model.layers.2.self_attn.k_proj", "model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value, mock_layer.self_attn.query_key_value,
), ),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): ( (2, "v_proj"): (
"model.layers.2.self_attn.v_proj", "model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value, mock_layer.self_attn.query_key_value,
@ -203,6 +211,10 @@ def test_get_attn_weights_gemma_compatibility():
"model.layers.2.self_attn.k_proj", "model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value, mock_layer.self_attn.query_key_value,
), ),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): ( (2, "v_proj"): (
"model.layers.2.self_attn.v_proj", "model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value, mock_layer.self_attn.query_key_value,