feat: improve weight loading to support adapters and add tests for StarCoder2 with LoRA

This commit is contained in:
drbh 2025-01-13 21:53:11 +00:00
parent 31778a6508
commit d611f0f5e2
9 changed files with 954 additions and 20 deletions

View File

@@ -0,0 +1,73 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.9355469,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40795898,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4599609,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.625,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23242188,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2294922,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
}
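
A quick way to sanity-check a snapshot file like the one above is to verify that the token stream and the final string agree. A minimal sketch, assuming the snapshot was saved to a standalone file (the filename here is hypothetical):

import json

# Load one of the snapshot files captured by the integration tests.
with open("test_flash_starcoder2_lora.json") as f:  # hypothetical path
    snapshot = json.load(f)

details = snapshot["details"]
# The number of streamed tokens should match the reported count...
assert details["generated_tokens"] == len(details["tokens"])
# ...and concatenating the token texts should reproduce generated_text.
assert snapshot["generated_text"] == "".join(t["text"] for t in details["tokens"])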

View File

@@ -0,0 +1,373 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 60,
"prefill": [],
"seed": 0,
"tokens": [
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 78,
"logprob": -1.0654297,
"special": false,
"text": "I"
},
{
"id": 3874,
"logprob": -0.4074707,
"special": false,
"text": " am"
},
{
"id": 331,
"logprob": -0.12695312,
"special": false,
"text": " a"
},
{
"id": 2951,
"logprob": -0.4501953,
"special": false,
"text": " software"
},
{
"id": 46380,
"logprob": -0.15124512,
"special": false,
"text": " engineer"
},
{
"id": 51,
"logprob": -0.953125,
"special": false,
"text": "."
},
{
"id": 222,
"logprob": -0.66259766,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 8204,
"logprob": -0.95947266,
"special": false,
"text": "What"
},
{
"id": 458,
"logprob": -0.3869629,
"special": false,
"text": " is"
},
{
"id": 1390,
"logprob": 0.0,
"special": false,
"text": " your"
},
{
"id": 27455,
"logprob": -0.07891846,
"special": false,
"text": " favorite"
},
{
"id": 16100,
"logprob": -0.4074707,
"special": false,
"text": " programming"
},
{
"id": 2940,
"logprob": 0.0,
"special": false,
"text": " language"
},
{
"id": 68,
"logprob": 0.0,
"special": false,
"text": "?"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 78,
"logprob": 0.0,
"special": false,
"text": "I"
},
{
"id": 2144,
"logprob": 0.0,
"special": false,
"text": " like"
},
{
"id": 5006,
"logprob": -0.10021973,
"special": false,
"text": " Python"
},
{
"id": 51,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 8204,
"logprob": 0.0,
"special": false,
"text": "What"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 1390,
"logprob": 0.0,
"special": false,
"text": " your"
},
{
"id": 27455,
"logprob": 0.0,
"special": false,
"text": " favorite"
},
{
"id": 16100,
"logprob": 0.0,
"special": false,
"text": " programming"
},
{
"id": 2940,
"logprob": 0.0,
"special": false,
"text": " language"
},
{
"id": 68,
"logprob": 0.0,
"special": false,
"text": "?"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 78,
"logprob": 0.0,
"special": false,
"text": "I"
},
{
"id": 2144,
"logprob": 0.0,
"special": false,
"text": " like"
},
{
"id": 5006,
"logprob": 0.0,
"special": false,
"text": " Python"
},
{
"id": 51,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 8204,
"logprob": 0.0,
"special": false,
"text": "What"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 1390,
"logprob": 0.0,
"special": false,
"text": " your"
},
{
"id": 27455,
"logprob": 0.0,
"special": false,
"text": " favorite"
},
{
"id": 16100,
"logprob": 0.0,
"special": false,
"text": " programming"
},
{
"id": 2940,
"logprob": 0.0,
"special": false,
"text": " language"
},
{
"id": 68,
"logprob": 0.0,
"special": false,
"text": "?"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 78,
"logprob": 0.0,
"special": false,
"text": "I"
},
{
"id": 2144,
"logprob": 0.0,
"special": false,
"text": " like"
},
{
"id": 5006,
"logprob": 0.0,
"special": false,
"text": " Python"
},
{
"id": 51,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 8204,
"logprob": 0.0,
"special": false,
"text": "What"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 1390,
"logprob": 0.0,
"special": false,
"text": " your"
},
{
"id": 27455,
"logprob": 0.0,
"special": false,
"text": " favorite"
},
{
"id": 16100,
"logprob": 0.0,
"special": false,
"text": " programming"
}
],
"top_tokens": null
},
"generated_text": "\n\nI am a software engineer.\n\nWhat is your favorite programming language?\n\nI like Python.\n\nWhat is your favorite programming language?\n\nI like Python.\n\nWhat is your favorite programming language?\n\nI like Python.\n\nWhat is your favorite programming"
}

View File

@@ -0,0 +1,294 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
}
]

View File

@@ -0,0 +1,71 @@
{
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 100,
"logprob": -0.9824219,
"special": false,
"text": "_"
},
{
"id": 5879,
"logprob": -0.3017578,
"special": false,
"text": "world"
},
{
"id": 2284,
"logprob": -0.68652344,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.27734375,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.4482422,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.54248047,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.4296875,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -0.8544922,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.7573242,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.81347656,
"special": false,
"text": "\n"
}
]
},
"generated_text": "_world():\n print(\"Hello World!\")\n"
}

View File

@@ -0,0 +1,78 @@
import pytest
import requests


@pytest.fixture(scope="module")
def flash_starcoder2_handle(launcher):
    with launcher(
        "bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"]
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_starcoder2(flash_starcoder2_handle):
    await flash_starcoder2_handle.health(300)
    return flash_starcoder2_handle.client


@pytest.mark.asyncio
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "def print_hello", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "who are you?",
        max_new_tokens=60,
        temperature=0.2,
        top_p=0.95,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 60
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_starcoder2_load(
    flash_starcoder2, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_starcoder2, "who are you?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])
    assert responses == response_snapshot


@pytest.mark.asyncio
async def test_flash_starcoder2_with_hugcode_adapter(
    flash_starcoder2, response_snapshot
):
    response = requests.post(
        f"{flash_starcoder2.base_url}/generate",
        headers=flash_starcoder2.headers,
        json={
            "inputs": "def print_hello",
            "parameters": {
                "max_new_tokens": 10,
                "adapter_id": "smangrul/starcoder-3b-hugcoder",
                "details": True,
            },
        },
    )

    assert response.status_code == 200

    data = response.json()
    assert data["generated_text"] == "_world():\n print(\"Hello World!\")\n"
    assert data == response_snapshot
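
Outside of pytest, the same adapter request from the last test can be issued against a running server by hand. A hedged sketch using the payload shown above; the base URL is an assumption about where the launcher exposed the server:

import requests

BASE_URL = "http://127.0.0.1:8080"  # assumption: a locally running TGI instance

response = requests.post(
    f"{BASE_URL}/generate",
    json={
        "inputs": "def print_hello",
        "parameters": {
            "max_new_tokens": 10,
            "adapter_id": "smangrul/starcoder-3b-hugcoder",
            "details": True,
        },
    },
)
response.raise_for_status()
# Per the test above, the adapter completes the prompt as
# "_world():\n print(\"Hello World!\")\n"
print(response.json()["generated_text"])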

View File

@@ -6,9 +6,11 @@ from collections import defaultdict
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Set, Tuple, Type, Union

+from loguru import logger
 import torch
 from peft import LoraConfig as _LoraConfig
 from torch.distributed import ProcessGroup

+from text_generation_server.utils.log import log_master
 from text_generation_server.adapters.config import AdapterConfig, ModuleMap
@@ -203,8 +205,17 @@ class LoraWeights(AdapterWeights):
         lora_a_list = [None] * nlayers
         lora_b_list = [None] * nlayers

+        # import ipdb; ipdb.set_trace()
         for layer_id in range(nlayers):
             key = (layer_id, layer_type)
+            if key not in target_to_layer:
+                # There is no layer of this type in the model
+                log_master(
+                    logger.warning,
+                    f"Key specified in lora weights but not found in base model: {key}",
+                )
+                return None
+
             weight_name, layer = target_to_layer[key]
             base_weight = layer.base_layer.linear.weight
             base_device = base_weight.device
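
The guard added above is the behavioral change: when a LoRA checkpoint names a (layer_id, layer_type) pair that the base model never registered, loading now warns and returns None instead of raising a KeyError mid-load. A minimal standalone sketch of that pattern, with a plain print standing in for log_master/logger.warning and hypothetical names elsewhere:

from typing import Optional

def gather_layers(nlayers: int, layer_type: str, target_to_layer: dict) -> Optional[list]:
    # Hypothetical distillation of the loop above: collect the target layer
    # for every (layer_id, layer_type) key, or bail out if any is missing.
    layers = []
    for layer_id in range(nlayers):
        key = (layer_id, layer_type)
        if key not in target_to_layer:
            print(f"Key specified in lora weights but not found in base model: {key}")
            return None  # the caller then skips this layer type entirely
        layers.append(target_to_layer[key])
    return layers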

View File

@@ -1449,6 +1449,9 @@ def get_model_with_lora_adapters(
         "up_proj",
         "down_proj",
         "qkv_proj",
+        # add c_* layers used in starcoder2
+        "c_proj",
+        "c_fc",
     ]

     for layer_name in adapter_layers:

View File

@@ -112,9 +112,6 @@ class Starcoder2Config(PretrainedConfig):

 def load_attention(config, prefix, weights, layer_id):
-    if config.num_attention_heads != config.num_key_value_heads:
-        base_layer = _load_gqa(config, prefix, weights)
-    else:
     prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
     head_size = config.hidden_size // config.num_attention_heads
     sizes = [
@@ -122,6 +119,9 @@ def load_attention(config, prefix, weights, layer_id):
         head_size * config.num_key_value_heads,
         head_size * config.num_key_value_heads,
     ]
+    if config.num_attention_heads != config.num_key_value_heads:
+        base_layer = _load_gqa(config, prefix, weights)
+    else:
         base_layer = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=prefixes,
@@ -239,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module):
         seqlen,
         max_s,
         prefill_cache_indices,
+        adapter_data,
     ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@@ -292,11 +293,13 @@ class Starcoder2Attention(torch.nn.Module):
             kv_scales=self.kv_scales,
         )

-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )


 class Starcoder2MLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, index):
         super().__init__()
         act = config.hidden_act
         self.act = (
@@ -310,23 +313,38 @@ class Starcoder2MLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.c_fc = TensorParallelColumnLinear.load(
+        c_fc = TensorParallelColumnLinear.load(
             config,
             prefix=f"{prefix}.c_fc",
             weights=weights,
             bias=config.use_bias,
         )
-        self.c_proj = TensorParallelRowLinear.load(
+        c_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.c_proj",
             weights=weights,
             bias=config.use_bias,
         )
+        self.c_fc = TensorParallelMultiAdapterLinear.load(
+            c_fc,
+            layer_id=index,
+            layer_names=[f"{prefix}.c_fc"],
+            sizes=[config.intermediate_size, config.intermediate_size],
+            process_group=weights.process_group,
+        )
+        self.c_proj = TensorParallelAdapterRowLinear.load(
+            c_proj,
+            index,
+            "c_proj",
+            process_group=weights.process_group,
+        )

-    def forward(self, hidden_states):
-        hidden_states = self.c_fc(hidden_states)
+    def forward(self, hidden_states, adapter_data):
+        hidden_states = self.c_fc(hidden_states, adapter_data)
         hidden_states = self.act(hidden_states)
-        return self.c_proj(hidden_states)
+        return self.c_proj(hidden_states, adapter_data)


 class Starcoder2GatedMLP(nn.Module):
@@ -379,10 +397,12 @@ class Starcoder2GatedMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )

-    def forward(self, hidden_states):
-        gate_up_states = self.gate_up_proj(hidden_states)
+    def forward(self, hidden_states, adapter_data):
+        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+        )


 STARCODER2_NORMALIZATION_CLASSES = {
@@ -405,7 +425,7 @@ class Starcoder2Layer(nn.Module):
         )
         self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
-            prefix=f"{prefix}.mlp", config=config, weights=weights
+            prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
         )

         self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
@@ -432,6 +452,7 @@ class Starcoder2Layer(nn.Module):
         seqlen,
         max_s,
         prefill_cache_indices,
+        adapter_data,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -447,6 +468,7 @@ class Starcoder2Layer(nn.Module):
             seqlen,
             max_s,
             prefill_cache_indices,
+            adapter_data,
         )

         # faster post attention rms norm
@@ -454,7 +476,7 @@ class Starcoder2Layer(nn.Module):
             attn_output, res
         )
-        mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)

         return mlp_output, attn_res
@@ -501,6 +523,7 @@ class Starcoder2Model(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
+        adapter_data,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -524,6 +547,7 @@ class Starcoder2Model(torch.nn.Module):
                 seqlen,
                 max_s,
                 prefill_cache_indices,
+                adapter_data,
             )

         hidden_states, _ = self.norm(hidden_states, residual)
@@ -595,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
             max_s,
             true_max_s,
             prefill_cache_indices,
+            adapter_data,
         )

         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
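
The recurring change in this file is that every fused linear layer is wrapped so its forward pass also receives adapter_data, which lets per-request LoRA deltas be applied on top of the frozen base weights. The following is only a schematic sketch of that idea, not the real TensorParallelMultiAdapterLinear (which additionally handles tensor parallelism and batches mixing different adapters):

import torch
import torch.nn as nn

class AdapterLinearSketch(nn.Module):
    def __init__(self, base: nn.Linear):
        super().__init__()
        self.base = base
        # Hypothetical registry: adapter id -> low-rank factors (A, B).
        self.loras = {}

    def forward(self, x: torch.Tensor, adapter_data) -> torch.Tensor:
        y = self.base(x)
        # Assumption: adapter_data carries the id of the active adapter;
        # the real implementation selects adapters per row of the batch.
        lora = self.loras.get(getattr(adapter_data, "adapter_id", None))
        if lora is not None:
            a, b = lora
            y = y + (x @ a) @ b  # additive low-rank LoRA update
        return y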

View File

@@ -281,6 +281,12 @@ def get_mlp_weights(i, layer):
     if hasattr(mlp, "up_proj"):
         weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)

+    if hasattr(mlp, "c_fc"):
+        weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc)
+
+    if hasattr(mlp, "c_proj"):
+        weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj)
+
     if hasattr(mlp, "down_proj"):
         weights[(i, "down_proj")] = (
             f"model.layers.{i}.mlp.down_proj",