feat: improve StarCoder to support multi-LoRA layers (#2883)

* feat: improve StarCoder to support multi-LoRA layers

* feat: improve weights that support adapters and add tests for StarCoder with LoRA

* fix: bump snapshot for added tests

* fix: rerun pre-commit lints

* fix: bump adapter test for added layer names
Authored by drbh on 2025-01-16 16:23:55 -05:00, committed via GitHub
parent 5f78ec32a5
commit 82f6ea1b71
10 changed files with 1009 additions and 25 deletions
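
The user-facing effect of this change is per-request adapter selection: the new integration test further down posts to /generate with an adapter_id naming the preloaded LoRA. A minimal standalone sketch of that call follows; the host and port are placeholders, and the server is assumed to have been launched with the adapter preloaded, as the test fixture does.

import requests

# Placeholder endpoint; point this at a running TGI instance that was
# started with the smangrul/starcoder-3b-hugcoder adapter preloaded.
resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "def print_hello",
        "parameters": {
            "max_new_tokens": 10,
            # Select the LoRA adapter for this single request.
            "adapter_id": "smangrul/starcoder-3b-hugcoder",
            "details": True,
        },
    },
)
resp.raise_for_status()
print(resp.json()["generated_text"])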

@@ -0,0 +1,73 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.9355469,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40795898,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4599609,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.625,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23242188,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2294922,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
}

@@ -0,0 +1,373 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 60,
"prefill": [],
"seed": 0,
"tokens": [
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -0.7944336,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -0.1796875,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": 0.0,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": 0.0,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": 0.0,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": 0.0,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": 0.0,
"special": false,
"text": "slide"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 700,
"logprob": 0.0,
"special": false,
"text": "type"
},
{
"id": 582,
"logprob": 0.0,
"special": false,
"text": "\":"
},
{
"id": 332,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 7277,
"logprob": -0.06994629,
"special": false,
"text": "slide"
},
{
"id": 3667,
"logprob": 0.0,
"special": false,
"text": "\"}"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 607,
"logprob": -0.8261719,
"special": false,
"text": " #"
},
{
"id": 244,
"logprob": -1.8574219,
"special": false,
"text": " "
},
{
"id": 55,
"logprob": -1.4541016,
"special": false,
"text": "2"
},
{
"id": 51,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 6208,
"logprob": -0.9794922,
"special": false,
"text": " What"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 341,
"logprob": 0.0,
"special": false,
"text": " the"
},
{
"id": 10609,
"logprob": -0.69189453,
"special": false,
"text": " difference"
},
{
"id": 3761,
"logprob": 0.0,
"special": false,
"text": " between"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 1168,
"logprob": -0.27172852,
"special": false,
"text": " list"
},
{
"id": 480,
"logprob": 0.0,
"special": false,
"text": " and"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 8871,
"logprob": 0.0,
"special": false,
"text": " tuple"
},
{
"id": 68,
"logprob": 0.0,
"special": false,
"text": "?"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -1.3359375,
"special": false,
"text": "#"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 449,
"logprob": -0.03164673,
"special": false,
"text": " -"
},
{
"id": 418,
"logprob": -1.0947266,
"special": false,
"text": " A"
},
{
"id": 1168,
"logprob": 0.0,
"special": false,
"text": " list"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 331,
"logprob": -0.3305664,
"special": false,
"text": " a"
},
{
"id": 14792,
"logprob": 0.0,
"special": false,
"text": " mutable"
},
{
"id": 6645,
"logprob": -0.40478516,
"special": false,
"text": " sequence"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 4725,
"logprob": -0.50390625,
"special": false,
"text": " elements"
},
{
"id": 49,
"logprob": -2.1269531,
"special": false,
"text": ","
},
{
"id": 2236,
"logprob": -0.1427002,
"special": false,
"text": " while"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 8871,
"logprob": 0.0,
"special": false,
"text": " tuple"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 619,
"logprob": 0.0,
"special": false,
"text": " an"
},
{
"id": 26079,
"logprob": 0.0,
"special": false,
"text": " immutable"
},
{
"id": 6645,
"logprob": 0.0,
"special": false,
"text": " sequence"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 4725,
"logprob": 0.0,
"special": false,
"text": " elements"
},
{
"id": 51,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 449,
"logprob": 0.0,
"special": false,
"text": " -"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide_type\": \"slide\"}\n# # 2. What is the difference between a list and a tuple?\n#\n# - A list is a mutable sequence of elements, while a tuple is an immutable sequence of elements.\n# -"
}

@@ -0,0 +1,294 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
}
]

@@ -0,0 +1,71 @@
{
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 100,
"logprob": -0.9824219,
"special": false,
"text": "_"
},
{
"id": 5879,
"logprob": -0.3017578,
"special": false,
"text": "world"
},
{
"id": 2284,
"logprob": -0.68652344,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.27734375,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.4482422,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.54248047,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.4296875,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -0.8544922,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.7573242,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.81347656,
"special": false,
"text": "\n"
}
]
},
"generated_text": "_world():\n print(\"Hello World!\")\n"
}

@@ -0,0 +1,79 @@
import pytest
import requests


@pytest.fixture(scope="module")
def flash_starcoder2_handle(launcher):
    with launcher(
        "bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"]
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_starcoder2(flash_starcoder2_handle):
    await flash_starcoder2_handle.health(300)
    return flash_starcoder2_handle.client


@pytest.mark.asyncio
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "def print_hello", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "who are you?",
        max_new_tokens=60,
        temperature=0.2,
        top_p=0.95,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 60
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_starcoder2_load(
    flash_starcoder2, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_starcoder2, "who are you?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])
    assert responses == response_snapshot


@pytest.mark.asyncio
async def test_flash_starcoder2_with_hugcode_adapter(
    flash_starcoder2, response_snapshot
):
    response = requests.post(
        f"{flash_starcoder2.base_url}/generate",
        headers=flash_starcoder2.headers,
        json={
            "inputs": "def print_hello",
            "parameters": {
                "max_new_tokens": 10,
                "adapter_id": "smangrul/starcoder-3b-hugcoder",
                "details": True,
            },
        },
    )

    assert response.status_code == 200

    data = response.json()
    assert data["generated_text"] == '_world():\n print("Hello World!")\n'
    assert data == response_snapshot

@@ -94,6 +94,8 @@ def test_get_mlp_weights_with_gate_up_proj():
# assert the result
expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
@@ -188,6 +190,8 @@ def test_get_mlp_weights_llama_compatibility():
result = get_mlp_weights(3, mock_layer)
expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
@@ -240,6 +244,8 @@ def test_get_mlp_weights_gemma_compatibility():
result = get_mlp_weights(3, mock_layer)
expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),

@@ -6,9 +6,11 @@ from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Type, Union
from loguru import logger
import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup
from text_generation_server.utils.log import log_master
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
@@ -203,8 +205,17 @@ class LoraWeights(AdapterWeights):
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers
# import ipdb; ipdb.set_trace()
for layer_id in range(nlayers):
key = (layer_id, layer_type)
if key not in target_to_layer:
# There is no layer of this type in the model
log_master(
logger.warning,
f"Key specified in lora weights but not found in base model: {key}",
)
return None
weight_name, layer = target_to_layer[key]
base_weight = layer.base_layer.linear.weight
base_device = base_weight.device
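
The guard added above lets an adapter that targets a module the base model does not expose be skipped with a warning rather than raising a KeyError during loading. The same pattern in isolation, as a hypothetical minimal helper, not the LoraWeights method itself:

from typing import Optional


def collect_adapter_layers(layer_type, nlayers, target_to_layer) -> Optional[list]:
    # target_to_layer maps (layer_id, layer_type) -> (weight_name, layer).
    collected = []
    for layer_id in range(nlayers):
        key = (layer_id, layer_type)
        if key not in target_to_layer:
            # Give up on this layer type entirely instead of crashing on a
            # partial match; the caller treats None as "skip".
            print(f"Key specified in lora weights but not found in base model: {key}")
            return None
        collected.append(target_to_layer[key])
    return collected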

@@ -1449,6 +1449,9 @@ def get_model_with_lora_adapters(
"up_proj",
"down_proj",
"qkv_proj",
# add c_* layers used in starcoder2
"c_proj",
"c_fc",
]
for layer_name in adapter_layers:

@@ -32,6 +32,8 @@ from text_generation_server.layers.attention import (
Seqlen,
)
from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
@@ -109,17 +111,31 @@ class Starcoder2Config(PretrainedConfig):
)
def load_attention(config, prefix, weights):
def load_attention(config, prefix, weights, layer_id):
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
head_size = config.hidden_size // config.num_attention_heads
sizes = [
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
]
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
base_layer = _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
base_layer = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
prefixes=prefixes,
dim=0,
weights=weights,
bias=config.use_bias,
)
return TensorParallelMultiAdapterLinear.load(
base_layer=base_layer,
layer_id=layer_id,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
def _load_gqa(config, prefix: str, weights):
@@ -157,6 +173,7 @@ def _load_gqa(config, prefix: str, weights):
class Starcoder2Attention(torch.nn.Module):
def __init__(
self,
index: int,
prefix: str,
config,
weights,
@@ -188,15 +205,23 @@ class Starcoder2Attention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.query_key_value = load_attention(config, prefix, weights, index)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=config.use_bias,
bias=getattr(config, "use_bias", False),
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
index,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -214,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module):
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
):
qkv = self.query_key_value(hidden_states)
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
@@ -267,11 +293,13 @@ class Starcoder2Attention(torch.nn.Module):
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class Starcoder2MLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, index):
super().__init__()
act = config.hidden_act
self.act = (
@@ -285,27 +313,42 @@ class Starcoder2MLP(nn.Module):
)
)
# Fuse gate and up proj
self.c_fc = TensorParallelColumnLinear.load(
c_fc = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.c_fc",
weights=weights,
bias=config.use_bias,
)
self.c_proj = TensorParallelRowLinear.load(
c_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=config.use_bias,
)
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
self.c_fc = TensorParallelMultiAdapterLinear.load(
c_fc,
layer_id=index,
layer_names=[f"{prefix}.c_fc"],
sizes=[config.intermediate_size, config.intermediate_size],
process_group=weights.process_group,
)
self.c_proj = TensorParallelAdapterRowLinear.load(
c_proj,
index,
"c_proj",
process_group=weights.process_group,
)
def forward(self, hidden_states, adapter_data):
hidden_states = self.c_fc(hidden_states, adapter_data)
hidden_states = self.act(hidden_states)
return self.c_proj(hidden_states)
return self.c_proj(hidden_states, adapter_data)
class Starcoder2GatedMLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, index, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
@@ -319,27 +362,47 @@ class Starcoder2GatedMLP(nn.Module):
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,
]
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
prefixes=prefixes,
weights=weights,
dim=0,
bias=config.use_bias,
)
self.down_proj = TensorParallelRowLinear.load(
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=config.use_bias,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
STARCODER2_NORMALIZATION_CLASSES = {
@@ -358,11 +421,11 @@ class Starcoder2Layer(nn.Module):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = Starcoder2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
)
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
prefix=f"{prefix}.mlp", config=config, weights=weights
prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
)
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
@@ -389,6 +452,7 @@ class Starcoder2Layer(nn.Module):
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -404,6 +468,7 @@ class Starcoder2Layer(nn.Module):
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
)
# faster post attention rms norm
@@ -411,7 +476,7 @@ class Starcoder2Layer(nn.Module):
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
return mlp_output, attn_res
@@ -458,6 +523,7 @@ class Starcoder2Model(torch.nn.Module):
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@@ -481,6 +547,7 @@ class Starcoder2Model(torch.nn.Module):
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
)
hidden_states, _ = self.norm(hidden_states, residual)
@@ -552,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
max_s,
true_max_s,
prefill_cache_indices,
adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
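
Most of the modeling changes above are plumbing: adapter_data is threaded from the top-level forward through every attention and MLP block so the multi-adapter linears can apply the right LoRA weights per request. A stripped-down illustration of that idea, not the real TensorParallelMultiAdapterLinear, and with adapter_data reduced to one adapter index per row instead of TGI's batch-level metadata:

import torch
import torch.nn as nn


class ToyMultiLoraLinear(nn.Module):
    # Applies a per-row low-rank update on top of a frozen base linear.
    # adapter_data holds one adapter index per row; 0 means "no adapter".
    def __init__(self, base, loras):
        super().__init__()
        self.base = base      # plain nn.Linear
        self.loras = loras    # list of (A, B) pairs: A is (in, r), B is (r, out)

    def forward(self, x, adapter_data):
        out = self.base(x)
        delta = torch.zeros_like(out)
        for idx, (a, b) in enumerate(self.loras, start=1):
            mask = adapter_data == idx
            if mask.any():
                delta[mask] = (x[mask] @ a) @ b
        return out + delta


# Usage sketch: rows 0 and 1 use adapter 1, row 2 uses the base weights only.
layer = ToyMultiLoraLinear(nn.Linear(8, 8, bias=False), [(torch.zeros(8, 4), torch.zeros(4, 8))])
x = torch.randn(3, 8)
y = layer(x, adapter_data=torch.tensor([1, 1, 0]))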

@@ -281,6 +281,12 @@ def get_mlp_weights(i, layer):
if hasattr(mlp, "up_proj"):
weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
if hasattr(mlp, "c_fc"):
weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc)
if hasattr(mlp, "c_proj"):
weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj)
if hasattr(mlp, "down_proj"):
weights[(i, "down_proj")] = (
f"model.layers.{i}.mlp.down_proj",