From d611f0f5e2f266f91b8aff8648955c2bed7c6124 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 13 Jan 2025 21:53:11 +0000 Subject: [PATCH] feat: improve weight that support adapters and add tests for starcoder with lora --- .../test_flash_starcoder2.json | 73 ++++ .../test_flash_starcoder2_default_params.json | 373 ++++++++++++++++++ .../test_flash_starcoder2_load.json | 294 ++++++++++++++ ...flash_starcoder2_with_hugcode_adapter.json | 71 ++++ .../models/test_flash_starcoder2_lora.py | 78 ++++ .../text_generation_server/adapters/lora.py | 11 + .../text_generation_server/models/__init__.py | 3 + .../flash_starcoder2_modeling.py | 65 ++- .../text_generation_server/utils/adapter.py | 6 + 9 files changed, 954 insertions(+), 20 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json create mode 100644 integration-tests/models/test_flash_starcoder2_lora.py diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json new file mode 100644 index 00000000..1bc1e0fd --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.9355469, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40795898, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4599609, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.625, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23242188, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2294922, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json new file mode 100644 index 00000000..1311c602 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json @@ -0,0 +1,373 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 60, + "prefill": [], + "seed": 0, + "tokens": [ + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 78, + "logprob": -1.0654297, + "special": false, + 
"text": "I" + }, + { + "id": 3874, + "logprob": -0.4074707, + "special": false, + "text": " am" + }, + { + "id": 331, + "logprob": -0.12695312, + "special": false, + "text": " a" + }, + { + "id": 2951, + "logprob": -0.4501953, + "special": false, + "text": " software" + }, + { + "id": 46380, + "logprob": -0.15124512, + "special": false, + "text": " engineer" + }, + { + "id": 51, + "logprob": -0.953125, + "special": false, + "text": "." + }, + { + "id": 222, + "logprob": -0.66259766, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 8204, + "logprob": -0.95947266, + "special": false, + "text": "What" + }, + { + "id": 458, + "logprob": -0.3869629, + "special": false, + "text": " is" + }, + { + "id": 1390, + "logprob": 0.0, + "special": false, + "text": " your" + }, + { + "id": 27455, + "logprob": -0.07891846, + "special": false, + "text": " favorite" + }, + { + "id": 16100, + "logprob": -0.4074707, + "special": false, + "text": " programming" + }, + { + "id": 2940, + "logprob": 0.0, + "special": false, + "text": " language" + }, + { + "id": 68, + "logprob": 0.0, + "special": false, + "text": "?" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 78, + "logprob": 0.0, + "special": false, + "text": "I" + }, + { + "id": 2144, + "logprob": 0.0, + "special": false, + "text": " like" + }, + { + "id": 5006, + "logprob": -0.10021973, + "special": false, + "text": " Python" + }, + { + "id": 51, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 8204, + "logprob": 0.0, + "special": false, + "text": "What" + }, + { + "id": 458, + "logprob": 0.0, + "special": false, + "text": " is" + }, + { + "id": 1390, + "logprob": 0.0, + "special": false, + "text": " your" + }, + { + "id": 27455, + "logprob": 0.0, + "special": false, + "text": " favorite" + }, + { + "id": 16100, + "logprob": 0.0, + "special": false, + "text": " programming" + }, + { + "id": 2940, + "logprob": 0.0, + "special": false, + "text": " language" + }, + { + "id": 68, + "logprob": 0.0, + "special": false, + "text": "?" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 78, + "logprob": 0.0, + "special": false, + "text": "I" + }, + { + "id": 2144, + "logprob": 0.0, + "special": false, + "text": " like" + }, + { + "id": 5006, + "logprob": 0.0, + "special": false, + "text": " Python" + }, + { + "id": 51, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 8204, + "logprob": 0.0, + "special": false, + "text": "What" + }, + { + "id": 458, + "logprob": 0.0, + "special": false, + "text": " is" + }, + { + "id": 1390, + "logprob": 0.0, + "special": false, + "text": " your" + }, + { + "id": 27455, + "logprob": 0.0, + "special": false, + "text": " favorite" + }, + { + "id": 16100, + "logprob": 0.0, + "special": false, + "text": " programming" + }, + { + "id": 2940, + "logprob": 0.0, + "special": false, + "text": " language" + }, + { + "id": 68, + "logprob": 0.0, + "special": false, + "text": "?" 
+ }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 78, + "logprob": 0.0, + "special": false, + "text": "I" + }, + { + "id": 2144, + "logprob": 0.0, + "special": false, + "text": " like" + }, + { + "id": 5006, + "logprob": 0.0, + "special": false, + "text": " Python" + }, + { + "id": 51, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 8204, + "logprob": 0.0, + "special": false, + "text": "What" + }, + { + "id": 458, + "logprob": 0.0, + "special": false, + "text": " is" + }, + { + "id": 1390, + "logprob": 0.0, + "special": false, + "text": " your" + }, + { + "id": 27455, + "logprob": 0.0, + "special": false, + "text": " favorite" + }, + { + "id": 16100, + "logprob": 0.0, + "special": false, + "text": " programming" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nI am a software engineer.\n\nWhat is your favorite programming language?\n\nI like Python.\n\nWhat is your favorite programming language?\n\nI like Python.\n\nWhat is your favorite programming language?\n\nI like Python.\n\nWhat is your favorite programming" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json new file mode 100644 index 00000000..bf9e3010 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json @@ -0,0 +1,294 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" 
+ }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json new file mode 100644 index 00000000..de76dd50 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json @@ -0,0 +1,71 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 100, + "logprob": -0.9824219, + "special": false, + "text": "_" + }, + { + "id": 5879, + "logprob": -0.3017578, + "special": false, + "text": "world" + }, + { + "id": 2284, + "logprob": -0.68652344, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.27734375, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.4482422, + "special": false, + "text": " print" + }, 
+ { + "id": 459, + "logprob": -0.54248047, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.4296875, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -0.8544922, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.7573242, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.81347656, + "special": false, + "text": "\n" + } + ] + }, + "generated_text": "_world():\n print(\"Hello World!\")\n" +} diff --git a/integration-tests/models/test_flash_starcoder2_lora.py b/integration-tests/models/test_flash_starcoder2_lora.py new file mode 100644 index 00000000..878e7d24 --- /dev/null +++ b/integration-tests/models/test_flash_starcoder2_lora.py @@ -0,0 +1,78 @@ +import pytest +import requests + + +@pytest.fixture(scope="module") +def flash_starcoder2_handle(launcher): + with launcher( + "bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"] + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_starcoder2(flash_starcoder2_handle): + await flash_starcoder2_handle.health(300) + return flash_starcoder2_handle.client + + +@pytest.mark.asyncio +async def test_flash_starcoder2(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "who are you?", + max_new_tokens=60, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 60 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_starcoder2_load( + flash_starcoder2, generate_load, response_snapshot +): + responses = await generate_load( + flash_starcoder2, "who are you?", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot + +@pytest.mark.asyncio +async def test_flash_starcoder2_with_hugcode_adapter( + flash_starcoder2, response_snapshot +): + response = requests.post( + f"{flash_starcoder2.base_url}/generate", + headers=flash_starcoder2.headers, + json={ + "inputs": "def print_hello", + "parameters": { + "max_new_tokens": 10, + "adapter_id": "smangrul/starcoder-3b-hugcoder", + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["generated_text"] == "_world():\n print(\"Hello World!\")\n" + + assert data == response_snapshot diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index f1edd9a0..cdcfe91b 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -6,9 +6,11 @@ from collections import defaultdict from dataclasses import dataclass from typing import Dict, List, Optional, Set, Tuple, Type, Union +from loguru import logger import torch from peft import LoraConfig as _LoraConfig from torch.distributed import ProcessGroup +from text_generation_server.utils.log import log_master from text_generation_server.adapters.config import AdapterConfig, ModuleMap @@ -203,8 +205,17 @@ class LoraWeights(AdapterWeights): lora_a_list = [None] * nlayers lora_b_list = [None] * 
nlayers + # import ipdb; ipdb.set_trace() for layer_id in range(nlayers): key = (layer_id, layer_type) + if key not in target_to_layer: + # There is no layer of this type in the model + log_master( + logger.warning, + f"Key specified in lora weights but not found in base model: {key}", + ) + return None + weight_name, layer = target_to_layer[key] base_weight = layer.base_layer.linear.weight base_device = base_weight.device diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index beefeb01..e2d24643 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1449,6 +1449,9 @@ def get_model_with_lora_adapters( "up_proj", "down_proj", "qkv_proj", + # add c_* layers used in starcoder2 + "c_proj", + "c_fc", ] for layer_name in adapter_layers: diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 571dc48e..5e090369 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -112,16 +112,16 @@ class Starcoder2Config(PretrainedConfig): def load_attention(config, prefix, weights, layer_id): + prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + head_size = config.hidden_size // config.num_attention_heads + sizes = [ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ] if config.num_attention_heads != config.num_key_value_heads: base_layer = _load_gqa(config, prefix, weights) else: - prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] - head_size = config.hidden_size // config.num_attention_heads - sizes = [ - head_size * config.num_attention_heads, - head_size * config.num_key_value_heads, - head_size * config.num_key_value_heads, - ] base_layer = TensorParallelColumnLinear.load_multi( config, prefixes=prefixes, @@ -239,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ): - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -292,11 +293,13 @@ class Starcoder2Attention(torch.nn.Module): kv_scales=self.kv_scales, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class Starcoder2MLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, index): super().__init__() act = config.hidden_act self.act = ( @@ -310,23 +313,38 @@ class Starcoder2MLP(nn.Module): ) ) # Fuse gate and up proj - self.c_fc = TensorParallelColumnLinear.load( + c_fc = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.c_fc", weights=weights, bias=config.use_bias, ) - self.c_proj = TensorParallelRowLinear.load( + c_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.c_proj", weights=weights, bias=config.use_bias, ) - def forward(self, hidden_states): - hidden_states = self.c_fc(hidden_states) + self.c_fc = TensorParallelMultiAdapterLinear.load( + c_fc, + layer_id=index, + layer_names=[f"{prefix}.c_fc"], + sizes=[config.intermediate_size, config.intermediate_size], + 
process_group=weights.process_group, + ) + + self.c_proj = TensorParallelAdapterRowLinear.load( + c_proj, + index, + "c_proj", + process_group=weights.process_group, + ) + + def forward(self, hidden_states, adapter_data): + hidden_states = self.c_fc(hidden_states, adapter_data) hidden_states = self.act(hidden_states) - return self.c_proj(hidden_states) + return self.c_proj(hidden_states, adapter_data) class Starcoder2GatedMLP(nn.Module): @@ -379,10 +397,12 @@ class Starcoder2GatedMLP(nn.Module): config.intermediate_size // weights.process_group.size() ) - def forward(self, hidden_states): - gate_up_states = self.gate_up_proj(hidden_states) + def forward(self, hidden_states, adapter_data): + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) STARCODER2_NORMALIZATION_CLASSES = { @@ -405,7 +425,7 @@ class Starcoder2Layer(nn.Module): ) self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type]( - prefix=f"{prefix}.mlp", config=config, weights=weights + prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id ) self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( @@ -432,6 +452,7 @@ class Starcoder2Layer(nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -447,6 +468,7 @@ class Starcoder2Layer(nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ) # faster post attention rms norm @@ -454,7 +476,7 @@ class Starcoder2Layer(nn.Module): attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) return mlp_output, attn_res @@ -501,6 +523,7 @@ class Starcoder2Model(torch.nn.Module): max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + adapter_data, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -524,6 +547,7 @@ class Starcoder2Model(torch.nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -595,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): max_s, true_max_s, prefill_cache_indices, + adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 09254b68..50abfafd 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -281,6 +281,12 @@ def get_mlp_weights(i, layer): if hasattr(mlp, "up_proj"): weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj) + if hasattr(mlp, "c_fc"): + weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc) + + if hasattr(mlp, "c_proj"): + weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj) + if hasattr(mlp, "down_proj"): weights[(i, "down_proj")] = ( f"model.layers.{i}.mlp.down_proj",
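            mlp.down_proj,
        )

A minimal, self-contained sketch (not TGI code) of the layer-name discovery this
last hunk extends: StarCoder2's MLP exposes c_fc/c_proj rather than the
gate_proj/up_proj/down_proj names probed for other architectures, so both naming
schemes are checked when building the adapter weight map. The SimpleNamespace
objects below are toy stand-ins for the real layer modules, and the
get_mlp_weights_sketch name is hypothetical.

    from types import SimpleNamespace

    def get_mlp_weights_sketch(i, layer):
        weights = {}
        mlp = getattr(layer, "mlp", None)
        if mlp is None:
            return weights
        # Probe both MLP naming schemes; StarCoder2 uses c_fc/c_proj.
        for name in ("gate_proj", "up_proj", "c_fc", "c_proj", "down_proj"):
            if hasattr(mlp, name):
                weights[(i, name)] = (
                    f"model.layers.{i}.mlp.{name}",
                    getattr(mlp, name),
                )
        return weights

    layer = SimpleNamespace(
        mlp=SimpleNamespace(c_fc="<c_fc linear>", c_proj="<c_proj linear>")
    )
    print(get_mlp_weights_sketch(0, layer))
    # -> {(0, 'c_fc'): ('model.layers.0.mlp.c_fc', '<c_fc linear>'),
    #     (0, 'c_proj'): ('model.layers.0.mlp.c_proj', '<c_proj linear>')}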