From 82f6ea1b714183f184a3ee8d4aa4fa7d59ac87ab Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 16 Jan 2025 16:23:55 -0500
Subject: [PATCH] feat: improve StarCoder2 to support multi-LoRA layers (#2883)

* feat: improve StarCoder2 to support multi-LoRA layers

* feat: improve weights that support adapters and add tests for StarCoder2 with LoRA

* fix: bump snapshot for added tests

* fix: re-run pre-commit lints

* fix: bump adapter test for added layer names
---
 .../test_flash_starcoder2.json                |  73 ++++
 .../test_flash_starcoder2_default_params.json | 373 ++++++++++++++++++
 .../test_flash_starcoder2_load.json           | 294 ++++++++++++++
 ...flash_starcoder2_with_hugcode_adapter.json |  71 ++++
 .../models/test_flash_starcoder2_lora.py      |  79 ++++
 server/tests/utils/test_adapter.py            |   6 +
 .../text_generation_server/adapters/lora.py   |  10 +
 .../text_generation_server/models/__init__.py |   3 +
 .../flash_starcoder2_modeling.py              | 118 ++++--
 .../text_generation_server/utils/adapter.py   |   6 +
 10 files changed, 1008 insertions(+), 25 deletions(-)
 create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json
 create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json
 create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json
 create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json
 create mode 100644 integration-tests/models/test_flash_starcoder2_lora.py

diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json
new file mode 100644
index 000000000..1bc1e0fde
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json
@@ -0,0 +1,73 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "length",
+    "generated_tokens": 10,
+    "prefill": [],
+    "seed": null,
+    "tokens": [
+      {
+        "id": 2284,
+        "logprob": -0.9355469,
+        "special": false,
+        "text": "():"
+      },
+      {
+        "id": 303,
+        "logprob": -0.40795898,
+        "special": false,
+        "text": "\n "
+      },
+      {
+        "id": 1489,
+        "logprob": -0.27954102,
+        "special": false,
+        "text": " print"
+      },
+      {
+        "id": 459,
+        "logprob": -0.6142578,
+        "special": false,
+        "text": "(\""
+      },
+      {
+        "id": 8302,
+        "logprob": -0.68310547,
+        "special": false,
+        "text": "Hello"
+      },
+      {
+        "id": 10914,
+        "logprob": -1.4599609,
+        "special": false,
+        "text": " World"
+      },
+      {
+        "id": 16013,
+        "logprob": -0.80126953,
+        "special": false,
+        "text": "!\")"
+      },
+      {
+        "id": 222,
+        "logprob": -0.625,
+        "special": false,
+        "text": "\n"
+      },
+      {
+        "id": 222,
+        "logprob": -0.23242188,
+        "special": false,
+        "text": "\n"
+      },
+      {
+        "id": 610,
+        "logprob": -1.2294922,
+        "special": false,
+        "text": "def"
+      }
+    ],
+    "top_tokens": null
+  },
+  "generated_text": "():\n print(\"Hello World!\")\n\ndef"
+}
diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json
new file mode 100644
index 000000000..ce3831b0e
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json
@@ -0,0 +1,373 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "length",
+ "generated_tokens": 60, + "prefill": [], + "seed": 0, + "tokens": [ + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -0.7944336, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -0.1796875, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": 0.0, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": 0.0, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": 0.0, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": 0.0, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": 0.0, + "special": false, + "text": "slide" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 700, + "logprob": 0.0, + "special": false, + "text": "type" + }, + { + "id": 582, + "logprob": 0.0, + "special": false, + "text": "\":" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 7277, + "logprob": -0.06994629, + "special": false, + "text": "slide" + }, + { + "id": 3667, + "logprob": 0.0, + "special": false, + "text": "\"}" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": 0.0, + "special": false, + "text": "#" + }, + { + "id": 607, + "logprob": -0.8261719, + "special": false, + "text": " #" + }, + { + "id": 244, + "logprob": -1.8574219, + "special": false, + "text": " " + }, + { + "id": 55, + "logprob": -1.4541016, + "special": false, + "text": "2" + }, + { + "id": 51, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 6208, + "logprob": -0.9794922, + "special": false, + "text": " What" + }, + { + "id": 458, + "logprob": 0.0, + "special": false, + "text": " is" + }, + { + "id": 341, + "logprob": 0.0, + "special": false, + "text": " the" + }, + { + "id": 10609, + "logprob": -0.69189453, + "special": false, + "text": " difference" + }, + { + "id": 3761, + "logprob": 0.0, + "special": false, + "text": " between" + }, + { + "id": 331, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 1168, + "logprob": -0.27172852, + "special": false, + "text": " list" + }, + { + "id": 480, + "logprob": 0.0, + "special": false, + "text": " and" + }, + { + "id": 331, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 8871, + "logprob": 0.0, + "special": false, + "text": " tuple" + }, + { + "id": 68, + "logprob": 0.0, + "special": false, + "text": "?" 
+ }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -1.3359375, + "special": false, + "text": "#" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": 0.0, + "special": false, + "text": "#" + }, + { + "id": 449, + "logprob": -0.03164673, + "special": false, + "text": " -" + }, + { + "id": 418, + "logprob": -1.0947266, + "special": false, + "text": " A" + }, + { + "id": 1168, + "logprob": 0.0, + "special": false, + "text": " list" + }, + { + "id": 458, + "logprob": 0.0, + "special": false, + "text": " is" + }, + { + "id": 331, + "logprob": -0.3305664, + "special": false, + "text": " a" + }, + { + "id": 14792, + "logprob": 0.0, + "special": false, + "text": " mutable" + }, + { + "id": 6645, + "logprob": -0.40478516, + "special": false, + "text": " sequence" + }, + { + "id": 451, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 4725, + "logprob": -0.50390625, + "special": false, + "text": " elements" + }, + { + "id": 49, + "logprob": -2.1269531, + "special": false, + "text": "," + }, + { + "id": 2236, + "logprob": -0.1427002, + "special": false, + "text": " while" + }, + { + "id": 331, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 8871, + "logprob": 0.0, + "special": false, + "text": " tuple" + }, + { + "id": 458, + "logprob": 0.0, + "special": false, + "text": " is" + }, + { + "id": 619, + "logprob": 0.0, + "special": false, + "text": " an" + }, + { + "id": 26079, + "logprob": 0.0, + "special": false, + "text": " immutable" + }, + { + "id": 6645, + "logprob": 0.0, + "special": false, + "text": " sequence" + }, + { + "id": 451, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 4725, + "logprob": 0.0, + "special": false, + "text": " elements" + }, + { + "id": 51, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": 0.0, + "special": false, + "text": "#" + }, + { + "id": 449, + "logprob": 0.0, + "special": false, + "text": " -" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide_type\": \"slide\"}\n# # 2. 
What is the difference between a list and a tuple?\n#\n# - A list is a mutable sequence of elements, while a tuple is an immutable sequence of elements.\n# -" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json new file mode 100644 index 000000000..bf9e3010d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json @@ -0,0 +1,294 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, 
+ "text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 222, + "logprob": -1.9091797, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -1.0478516, + "special": false, + "text": "\n" + }, + { + "id": 40, + "logprob": -3.015625, + "special": false, + "text": "#" + }, + { + "id": 494, + "logprob": -1.4228516, + "special": false, + "text": " +" + }, + { + "id": 447, + "logprob": -1.1025391, + "special": false, + "text": " [" + }, + { + "id": 9009, + "logprob": -0.0008444786, + "special": false, + "text": "markdown" + }, + { + "id": 98, + "logprob": -8.8095665e-05, + "special": false, + "text": "]" + }, + { + "id": 37402, + "logprob": -0.5810547, + "special": false, + "text": " slideshow" + }, + { + "id": 8492, + "logprob": -0.00022864342, + "special": false, + "text": "={\"" + }, + { + "id": 7277, + "logprob": -0.00030994415, + "special": false, + "text": "slide" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# + [markdown] slideshow={\"slide" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json new file mode 100644 index 000000000..de76dd50e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json @@ -0,0 +1,71 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 100, + "logprob": -0.9824219, + "special": false, + "text": "_" + }, + { + "id": 5879, + "logprob": -0.3017578, + "special": false, + "text": "world" + }, + { + "id": 2284, + "logprob": -0.68652344, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.27734375, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.4482422, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.54248047, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.4296875, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -0.8544922, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.7573242, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.81347656, + "special": false, + "text": "\n" + } + ] + }, + "generated_text": "_world():\n print(\"Hello World!\")\n" +} diff --git a/integration-tests/models/test_flash_starcoder2_lora.py b/integration-tests/models/test_flash_starcoder2_lora.py new file mode 100644 index 000000000..6480f6699 --- /dev/null +++ b/integration-tests/models/test_flash_starcoder2_lora.py @@ -0,0 +1,79 @@ +import pytest +import requests + + +@pytest.fixture(scope="module") +def flash_starcoder2_handle(launcher): + with launcher( + "bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"] + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_starcoder2(flash_starcoder2_handle): + await flash_starcoder2_handle.health(300) + return 
flash_starcoder2_handle.client + + +@pytest.mark.asyncio +async def test_flash_starcoder2(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "who are you?", + max_new_tokens=60, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 60 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_starcoder2_load( + flash_starcoder2, generate_load, response_snapshot +): + responses = await generate_load( + flash_starcoder2, "who are you?", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_starcoder2_with_hugcode_adapter( + flash_starcoder2, response_snapshot +): + response = requests.post( + f"{flash_starcoder2.base_url}/generate", + headers=flash_starcoder2.headers, + json={ + "inputs": "def print_hello", + "parameters": { + "max_new_tokens": 10, + "adapter_id": "smangrul/starcoder-3b-hugcoder", + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["generated_text"] == '_world():\n print("Hello World!")\n' + + assert data == response_snapshot diff --git a/server/tests/utils/test_adapter.py b/server/tests/utils/test_adapter.py index a27c10551..ab0312e4c 100644 --- a/server/tests/utils/test_adapter.py +++ b/server/tests/utils/test_adapter.py @@ -94,6 +94,8 @@ def test_get_mlp_weights_with_gate_up_proj(): # assert the result expected = { + (3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc), + (3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), @@ -188,6 +190,8 @@ def test_get_mlp_weights_llama_compatibility(): result = get_mlp_weights(3, mock_layer) expected = { + (3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc), + (3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), @@ -240,6 +244,8 @@ def test_get_mlp_weights_gemma_compatibility(): result = get_mlp_weights(3, mock_layer) expected = { + (3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc), + (3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index f1edd9a07..cdcfe91b1 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -6,9 +6,11 
@@ from collections import defaultdict
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Set, Tuple, Type, Union
 
+from loguru import logger
 import torch
 from peft import LoraConfig as _LoraConfig
 from torch.distributed import ProcessGroup
+from text_generation_server.utils.log import log_master
 
 from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 
@@ -203,8 +205,16 @@ class LoraWeights(AdapterWeights):
         lora_a_list = [None] * nlayers
         lora_b_list = [None] * nlayers
 
         for layer_id in range(nlayers):
             key = (layer_id, layer_type)
+            if key not in target_to_layer:
+                # There is no layer of this type in the model
+                log_master(
+                    logger.warning,
+                    f"Key specified in lora weights but not found in base model: {key}",
+                )
+                return None
+
             weight_name, layer = target_to_layer[key]
             base_weight = layer.base_layer.linear.weight
             base_device = base_weight.device
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index beefeb016..e2d24643e 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -1449,6 +1449,9 @@ def get_model_with_lora_adapters(
         "up_proj",
         "down_proj",
         "qkv_proj",
+        # add c_* layers used in starcoder2
+        "c_proj",
+        "c_fc",
     ]
 
     for layer_name in adapter_layers:
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index c793982d8..5e090369b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -32,6 +32,8 @@ from text_generation_server.layers.attention import (
     Seqlen,
 )
 from text_generation_server.layers import (
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -109,17 +111,31 @@ class Starcoder2Config(PretrainedConfig):
     )
 
 
-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
+    prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+    head_size = config.hidden_size // config.num_attention_heads
+    sizes = [
+        head_size * config.num_attention_heads,
+        head_size * config.num_key_value_heads,
+        head_size * config.num_key_value_heads,
+    ]
     if config.num_attention_heads != config.num_key_value_heads:
-        return _load_gqa(config, prefix, weights)
+        base_layer = _load_gqa(config, prefix, weights)
     else:
-        return TensorParallelColumnLinear.load_multi(
+        base_layer = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            prefixes=prefixes,
             dim=0,
             weights=weights,
             bias=config.use_bias,
         )
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer=base_layer,
+        layer_id=layer_id,
+        layer_names=prefixes,
+        sizes=sizes,
+        process_group=weights.process_group,
+    )
 
 
 def _load_gqa(config, prefix: str, weights):
@@ -157,6 +173,7 @@ def _load_gqa(config, prefix: str, weights):
 class Starcoder2Attention(torch.nn.Module):
     def __init__(
         self,
+        index: int,
         prefix: str,
         config,
         weights,
@@ -188,15 +205,23 @@ class Starcoder2Attention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
 
-        self.query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = load_attention(config, prefix, weights, 
index) self.kv_scales = get_kv_scales(weights, f"{prefix}") - self.o_proj = TensorParallelRowLinear.load( + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, - bias=config.use_bias, + bias=getattr(config, "use_bias", False), ) + + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + index, + "o_proj", + process_group=weights.process_group, + ) + self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -214,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ): - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -267,11 +293,13 @@ class Starcoder2Attention(torch.nn.Module): kv_scales=self.kv_scales, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class Starcoder2MLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, index): super().__init__() act = config.hidden_act self.act = ( @@ -285,27 +313,42 @@ class Starcoder2MLP(nn.Module): ) ) # Fuse gate and up proj - self.c_fc = TensorParallelColumnLinear.load( + c_fc = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.c_fc", weights=weights, bias=config.use_bias, ) - self.c_proj = TensorParallelRowLinear.load( + c_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.c_proj", weights=weights, bias=config.use_bias, ) - def forward(self, hidden_states): - hidden_states = self.c_fc(hidden_states) + self.c_fc = TensorParallelMultiAdapterLinear.load( + c_fc, + layer_id=index, + layer_names=[f"{prefix}.c_fc"], + sizes=[config.intermediate_size, config.intermediate_size], + process_group=weights.process_group, + ) + + self.c_proj = TensorParallelAdapterRowLinear.load( + c_proj, + index, + "c_proj", + process_group=weights.process_group, + ) + + def forward(self, hidden_states, adapter_data): + hidden_states = self.c_fc(hidden_states, adapter_data) hidden_states = self.act(hidden_states) - return self.c_proj(hidden_states) + return self.c_proj(hidden_states, adapter_data) class Starcoder2GatedMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, index, prefix, config, weights): super().__init__() act = config.hidden_act self.act = ( @@ -319,27 +362,47 @@ class Starcoder2GatedMLP(nn.Module): ) ) # Fuse gate and up proj - self.gate_up_proj = TensorParallelColumnLinear.load_multi( + prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"] + sizes = [ + config.intermediate_size, + config.intermediate_size, + ] + gate_up_proj = TensorParallelColumnLinear.load_multi( config, - prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + prefixes=prefixes, weights=weights, dim=0, bias=config.use_bias, ) - self.down_proj = TensorParallelRowLinear.load( + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + index, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) + down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=config.use_bias, ) + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + index, + "down_proj", + process_group=weights.process_group, + ) self.intermediate_size = ( 
config.intermediate_size // weights.process_group.size() ) - def forward(self, hidden_states): - gate_up_states = self.gate_up_proj(hidden_states) + def forward(self, hidden_states, adapter_data): + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) STARCODER2_NORMALIZATION_CLASSES = { @@ -358,11 +421,11 @@ class Starcoder2Layer(nn.Module): super().__init__() prefix = f"model.layers.{layer_id}" self.self_attn = Starcoder2Attention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id ) self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type]( - prefix=f"{prefix}.mlp", config=config, weights=weights + prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id ) self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( @@ -389,6 +452,7 @@ class Starcoder2Layer(nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -404,6 +468,7 @@ class Starcoder2Layer(nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ) # faster post attention rms norm @@ -411,7 +476,7 @@ class Starcoder2Layer(nn.Module): attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) return mlp_output, attn_res @@ -458,6 +523,7 @@ class Starcoder2Model(torch.nn.Module): max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + adapter_data, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -481,6 +547,7 @@ class Starcoder2Model(torch.nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -552,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): max_s, true_max_s, prefill_cache_indices, + adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 09254b68a..50abfafd5 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -281,6 +281,12 @@ def get_mlp_weights(i, layer): if hasattr(mlp, "up_proj"): weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj) + if hasattr(mlp, "c_fc"): + weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc) + + if hasattr(mlp, "c_proj"): + weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj) + if hasattr(mlp, "down_proj"): weights[(i, "down_proj")] = ( f"model.layers.{i}.mlp.down_proj",
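-- 
Usage sketch (not part of the patch): the per-request adapter routing this
change enables can be exercised against a running server the same way
test_flash_starcoder2_with_hugcode_adapter does above. The launcher invocation
and base URL below are assumptions for illustration; the /generate payload
shape is taken directly from that test.

    # Assumes a server started roughly like the test fixture, e.g.:
    #   text-generation-launcher --model-id bigcode/starcoder2-3b \
    #     --lora-adapters smangrul/starcoder-3b-hugcoder
    import requests

    BASE_URL = "http://127.0.0.1:8080"  # assumed local endpoint

    response = requests.post(
        f"{BASE_URL}/generate",
        json={
            "inputs": "def print_hello",
            "parameters": {
                "max_new_tokens": 10,
                # route this request through the preloaded LoRA adapter;
                # omit adapter_id to generate with the base model weights
                "adapter_id": "smangrul/starcoder-3b-hugcoder",
            },
        },
    )
    response.raise_for_status()
    print(response.json()["generated_text"])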