feat: improve star coder to support multi lora layers (#2883)

* feat: improve star coder to support multi lora layers

* feat: improve weight that support adapters and add tests for starcoder with lora

* fix: bump snapshot for added tests

* fix: rerun pre commit lints

* fix: bump adapter test for added later names
This commit is contained in:
drbh 2025-01-16 16:23:55 -05:00 committed by GitHub
parent 5f78ec32a5
commit 82f6ea1b71
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1009 additions and 25 deletions

View File

@ -0,0 +1,73 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.9355469,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40795898,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4599609,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.625,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23242188,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2294922,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
}

View File

@ -0,0 +1,373 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 60,
"prefill": [],
"seed": 0,
"tokens": [
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -0.7944336,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -0.1796875,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": 0.0,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": 0.0,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": 0.0,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": 0.0,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": 0.0,
"special": false,
"text": "slide"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 700,
"logprob": 0.0,
"special": false,
"text": "type"
},
{
"id": 582,
"logprob": 0.0,
"special": false,
"text": "\":"
},
{
"id": 332,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 7277,
"logprob": -0.06994629,
"special": false,
"text": "slide"
},
{
"id": 3667,
"logprob": 0.0,
"special": false,
"text": "\"}"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 607,
"logprob": -0.8261719,
"special": false,
"text": " #"
},
{
"id": 244,
"logprob": -1.8574219,
"special": false,
"text": " "
},
{
"id": 55,
"logprob": -1.4541016,
"special": false,
"text": "2"
},
{
"id": 51,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 6208,
"logprob": -0.9794922,
"special": false,
"text": " What"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 341,
"logprob": 0.0,
"special": false,
"text": " the"
},
{
"id": 10609,
"logprob": -0.69189453,
"special": false,
"text": " difference"
},
{
"id": 3761,
"logprob": 0.0,
"special": false,
"text": " between"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 1168,
"logprob": -0.27172852,
"special": false,
"text": " list"
},
{
"id": 480,
"logprob": 0.0,
"special": false,
"text": " and"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 8871,
"logprob": 0.0,
"special": false,
"text": " tuple"
},
{
"id": 68,
"logprob": 0.0,
"special": false,
"text": "?"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -1.3359375,
"special": false,
"text": "#"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 449,
"logprob": -0.03164673,
"special": false,
"text": " -"
},
{
"id": 418,
"logprob": -1.0947266,
"special": false,
"text": " A"
},
{
"id": 1168,
"logprob": 0.0,
"special": false,
"text": " list"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 331,
"logprob": -0.3305664,
"special": false,
"text": " a"
},
{
"id": 14792,
"logprob": 0.0,
"special": false,
"text": " mutable"
},
{
"id": 6645,
"logprob": -0.40478516,
"special": false,
"text": " sequence"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 4725,
"logprob": -0.50390625,
"special": false,
"text": " elements"
},
{
"id": 49,
"logprob": -2.1269531,
"special": false,
"text": ","
},
{
"id": 2236,
"logprob": -0.1427002,
"special": false,
"text": " while"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 8871,
"logprob": 0.0,
"special": false,
"text": " tuple"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 619,
"logprob": 0.0,
"special": false,
"text": " an"
},
{
"id": 26079,
"logprob": 0.0,
"special": false,
"text": " immutable"
},
{
"id": 6645,
"logprob": 0.0,
"special": false,
"text": " sequence"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 4725,
"logprob": 0.0,
"special": false,
"text": " elements"
},
{
"id": 51,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 449,
"logprob": 0.0,
"special": false,
"text": " -"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide_type\": \"slide\"}\n# # 2. What is the difference between a list and a tuple?\n#\n# - A list is a mutable sequence of elements, while a tuple is an immutable sequence of elements.\n# -"
}

View File

@ -0,0 +1,294 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
}
]

View File

@ -0,0 +1,71 @@
{
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 100,
"logprob": -0.9824219,
"special": false,
"text": "_"
},
{
"id": 5879,
"logprob": -0.3017578,
"special": false,
"text": "world"
},
{
"id": 2284,
"logprob": -0.68652344,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.27734375,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.4482422,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.54248047,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.4296875,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -0.8544922,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.7573242,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.81347656,
"special": false,
"text": "\n"
}
]
},
"generated_text": "_world():\n print(\"Hello World!\")\n"
}

View File

@ -0,0 +1,79 @@
import pytest
import requests
@pytest.fixture(scope="module")
def flash_starcoder2_handle(launcher):
with launcher(
"bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"]
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_starcoder2(flash_starcoder2_handle):
await flash_starcoder2_handle.health(300)
return flash_starcoder2_handle.client
@pytest.mark.asyncio
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
response = await flash_starcoder2.generate(
"def print_hello", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
response = await flash_starcoder2.generate(
"who are you?",
max_new_tokens=60,
temperature=0.2,
top_p=0.95,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 60
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_starcoder2_load(
flash_starcoder2, generate_load, response_snapshot
):
responses = await generate_load(
flash_starcoder2, "who are you?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
@pytest.mark.asyncio
async def test_flash_starcoder2_with_hugcode_adapter(
flash_starcoder2, response_snapshot
):
response = requests.post(
f"{flash_starcoder2.base_url}/generate",
headers=flash_starcoder2.headers,
json={
"inputs": "def print_hello",
"parameters": {
"max_new_tokens": 10,
"adapter_id": "smangrul/starcoder-3b-hugcoder",
"details": True,
},
},
)
assert response.status_code == 200
data = response.json()
assert data["generated_text"] == '_world():\n print("Hello World!")\n'
assert data == response_snapshot

View File

@ -94,6 +94,8 @@ def test_get_mlp_weights_with_gate_up_proj():
# assert the result # assert the result
expected = { expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
@ -188,6 +190,8 @@ def test_get_mlp_weights_llama_compatibility():
result = get_mlp_weights(3, mock_layer) result = get_mlp_weights(3, mock_layer)
expected = { expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
@ -240,6 +244,8 @@ def test_get_mlp_weights_gemma_compatibility():
result = get_mlp_weights(3, mock_layer) result = get_mlp_weights(3, mock_layer)
expected = { expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),

View File

@ -6,9 +6,11 @@ from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Type, Union from typing import Dict, List, Optional, Set, Tuple, Type, Union
from loguru import logger
import torch import torch
from peft import LoraConfig as _LoraConfig from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from text_generation_server.utils.log import log_master
from text_generation_server.adapters.config import AdapterConfig, ModuleMap from text_generation_server.adapters.config import AdapterConfig, ModuleMap
@ -203,8 +205,17 @@ class LoraWeights(AdapterWeights):
lora_a_list = [None] * nlayers lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers lora_b_list = [None] * nlayers
# import ipdb; ipdb.set_trace()
for layer_id in range(nlayers): for layer_id in range(nlayers):
key = (layer_id, layer_type) key = (layer_id, layer_type)
if key not in target_to_layer:
# There is no layer of this type in the model
log_master(
logger.warning,
f"Key specified in lora weights but not found in base model: {key}",
)
return None
weight_name, layer = target_to_layer[key] weight_name, layer = target_to_layer[key]
base_weight = layer.base_layer.linear.weight base_weight = layer.base_layer.linear.weight
base_device = base_weight.device base_device = base_weight.device

View File

@ -1449,6 +1449,9 @@ def get_model_with_lora_adapters(
"up_proj", "up_proj",
"down_proj", "down_proj",
"qkv_proj", "qkv_proj",
# add c_* layers used in starcoder2
"c_proj",
"c_fc",
] ]
for layer_name in adapter_layers: for layer_name in adapter_layers:

View File

@ -32,6 +32,8 @@ from text_generation_server.layers.attention import (
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
@ -109,17 +111,31 @@ class Starcoder2Config(PretrainedConfig):
) )
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights, layer_id):
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
head_size = config.hidden_size // config.num_attention_heads
sizes = [
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
]
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) base_layer = _load_gqa(config, prefix, weights)
else: else:
return TensorParallelColumnLinear.load_multi( base_layer = TensorParallelColumnLinear.load_multi(
config, config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=prefixes,
dim=0, dim=0,
weights=weights, weights=weights,
bias=config.use_bias, bias=config.use_bias,
) )
return TensorParallelMultiAdapterLinear.load(
base_layer=base_layer,
layer_id=layer_id,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
def _load_gqa(config, prefix: str, weights): def _load_gqa(config, prefix: str, weights):
@ -157,6 +173,7 @@ def _load_gqa(config, prefix: str, weights):
class Starcoder2Attention(torch.nn.Module): class Starcoder2Attention(torch.nn.Module):
def __init__( def __init__(
self, self,
index: int,
prefix: str, prefix: str,
config, config,
weights, weights,
@ -188,15 +205,23 @@ class Starcoder2Attention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size() config.num_key_value_heads // weights.process_group.size()
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights, index)
self.kv_scales = get_kv_scales(weights, f"{prefix}") self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
weights=weights, weights=weights,
bias=config.use_bias, bias=getattr(config, "use_bias", False),
) )
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
index,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange( self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -214,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module):
seqlen, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split( query, kv = qkv.split(
[ [
self.head_size * self.num_heads, self.head_size * self.num_heads,
@ -267,11 +293,13 @@ class Starcoder2Attention(torch.nn.Module):
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class Starcoder2MLP(nn.Module): class Starcoder2MLP(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights, index):
super().__init__() super().__init__()
act = config.hidden_act act = config.hidden_act
self.act = ( self.act = (
@ -285,27 +313,42 @@ class Starcoder2MLP(nn.Module):
) )
) )
# Fuse gate and up proj # Fuse gate and up proj
self.c_fc = TensorParallelColumnLinear.load( c_fc = TensorParallelColumnLinear.load(
config, config,
prefix=f"{prefix}.c_fc", prefix=f"{prefix}.c_fc",
weights=weights, weights=weights,
bias=config.use_bias, bias=config.use_bias,
) )
self.c_proj = TensorParallelRowLinear.load( c_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.c_proj", prefix=f"{prefix}.c_proj",
weights=weights, weights=weights,
bias=config.use_bias, bias=config.use_bias,
) )
def forward(self, hidden_states): self.c_fc = TensorParallelMultiAdapterLinear.load(
hidden_states = self.c_fc(hidden_states) c_fc,
layer_id=index,
layer_names=[f"{prefix}.c_fc"],
sizes=[config.intermediate_size, config.intermediate_size],
process_group=weights.process_group,
)
self.c_proj = TensorParallelAdapterRowLinear.load(
c_proj,
index,
"c_proj",
process_group=weights.process_group,
)
def forward(self, hidden_states, adapter_data):
hidden_states = self.c_fc(hidden_states, adapter_data)
hidden_states = self.act(hidden_states) hidden_states = self.act(hidden_states)
return self.c_proj(hidden_states) return self.c_proj(hidden_states, adapter_data)
class Starcoder2GatedMLP(nn.Module): class Starcoder2GatedMLP(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, index, prefix, config, weights):
super().__init__() super().__init__()
act = config.hidden_act act = config.hidden_act
self.act = ( self.act = (
@ -319,27 +362,47 @@ class Starcoder2GatedMLP(nn.Module):
) )
) )
# Fuse gate and up proj # Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi( prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,
]
gate_up_proj = TensorParallelColumnLinear.load_multi(
config, config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], prefixes=prefixes,
weights=weights, weights=weights,
dim=0, dim=0,
bias=config.use_bias, bias=config.use_bias,
) )
self.down_proj = TensorParallelRowLinear.load( self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
weights=weights, weights=weights,
bias=config.use_bias, bias=config.use_bias,
) )
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = ( self.intermediate_size = (
config.intermediate_size // weights.process_group.size() config.intermediate_size // weights.process_group.size()
) )
def forward(self, hidden_states): def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
STARCODER2_NORMALIZATION_CLASSES = { STARCODER2_NORMALIZATION_CLASSES = {
@ -358,11 +421,11 @@ class Starcoder2Layer(nn.Module):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}" prefix = f"model.layers.{layer_id}"
self.self_attn = Starcoder2Attention( self.self_attn = Starcoder2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
) )
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type]( self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
prefix=f"{prefix}.mlp", config=config, weights=weights prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
) )
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
@ -389,6 +452,7 @@ class Starcoder2Layer(nn.Module):
seqlen, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -404,6 +468,7 @@ class Starcoder2Layer(nn.Module):
seqlen, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data,
) )
# faster post attention rms norm # faster post attention rms norm
@ -411,7 +476,7 @@ class Starcoder2Layer(nn.Module):
attn_output, res attn_output, res
) )
mlp_output = self.mlp(normed_attn_res_output) mlp_output = self.mlp(normed_attn_res_output, adapter_data)
return mlp_output, attn_res return mlp_output, attn_res
@ -458,6 +523,7 @@ class Starcoder2Model(torch.nn.Module):
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
@ -481,6 +547,7 @@ class Starcoder2Model(torch.nn.Module):
seqlen, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@ -552,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
max_s, max_s,
true_max_s, true_max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -281,6 +281,12 @@ def get_mlp_weights(i, layer):
if hasattr(mlp, "up_proj"): if hasattr(mlp, "up_proj"):
weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj) weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
if hasattr(mlp, "c_fc"):
weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc)
if hasattr(mlp, "c_proj"):
weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj)
if hasattr(mlp, "down_proj"): if hasattr(mlp, "down_proj"):
weights[(i, "down_proj")] = ( weights[(i, "down_proj")] = (
f"model.layers.{i}.mlp.down_proj", f"model.layers.{i}.mlp.down_proj",