From 85b5ce653942eb60303883d1cf03bf12a642162f Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 2 Sep 2024 13:57:37 +0000 Subject: [PATCH] fix: add qkv_proj weights to weight test --- server/tests/utils/test_adapter.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/server/tests/utils/test_adapter.py b/server/tests/utils/test_adapter.py index 1eca0d7b..a27c1055 100644 --- a/server/tests/utils/test_adapter.py +++ b/server/tests/utils/test_adapter.py @@ -70,6 +70,10 @@ def test_get_attn_weights(): "model.layers.2.self_attn.k_proj", mock_layer.self_attn.query_key_value, ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), (2, "v_proj"): ( "model.layers.2.self_attn.v_proj", mock_layer.self_attn.query_key_value, @@ -163,6 +167,10 @@ def test_get_attn_weights_llama_compatibility(): "model.layers.2.self_attn.k_proj", mock_layer.self_attn.query_key_value, ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), (2, "v_proj"): ( "model.layers.2.self_attn.v_proj", mock_layer.self_attn.query_key_value, @@ -203,6 +211,10 @@ def test_get_attn_weights_gemma_compatibility(): "model.layers.2.self_attn.k_proj", mock_layer.self_attn.query_key_value, ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), (2, "v_proj"): ( "model.layers.2.self_attn.v_proj", mock_layer.self_attn.query_key_value,