mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
add test
This commit is contained in:
parent
ae5beb9d7b
commit
1d18dbd47e
44
server/tests/utils/test_adapter.py
Normal file
44
server/tests/utils/test_adapter.py
Normal file
@ -0,0 +1,44 @@
|
||||
import torch
|
||||
from peft import LoraConfig
|
||||
|
||||
from text_generation_server.utils.adapter import merge_adapter_weights
|
||||
|
||||
|
||||
def test_merge_adapter_weights():
|
||||
W_0 = torch.tensor([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6],
|
||||
[7, 8, 9]
|
||||
])
|
||||
model_weights = {
|
||||
"model.layers.10.self_attn.q_proj.weight": W_0
|
||||
}
|
||||
|
||||
A = torch.tensor([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
])
|
||||
B = torch.tensor([
|
||||
[1, 2],
|
||||
[3, 4],
|
||||
[5, 6]
|
||||
])
|
||||
adapter_weights = {
|
||||
"base_model.model.model.layers.10.self_attn.q_proj.lora_A.weight": A,
|
||||
"base_model.model.model.layers.10.self_attn.q_proj.lora_B.weight": B
|
||||
}
|
||||
|
||||
W_expected = torch.tensor([
|
||||
[ 5.5000, 8.0000, 10.5000],
|
||||
[13.5000, 18.0000, 22.5000],
|
||||
[21.5000, 28.0000, 34.5000]
|
||||
])
|
||||
adapter_config = LoraConfig(r=2, lora_alpha=1, fan_in_fan_out=False)
|
||||
merged_weights, processed_adapter_weight_names = merge_adapter_weights(model_weights, adapter_weights, adapter_config)
|
||||
|
||||
assert len(merged_weights) == 1
|
||||
assert merged_weights["model.layers.10.self_attn.q_proj.weight"].equal(W_expected)
|
||||
|
||||
assert len(processed_adapter_weight_names) == 2
|
||||
assert "base_model.model.model.layers.10.self_attn.q_proj.lora_A.weight" in processed_adapter_weight_names
|
||||
assert "base_model.model.model.layers.10.self_attn.q_proj.lora_B.weight" in processed_adapter_weight_names
|
Loading…
Reference in New Issue
Block a user