mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 06:42:10 +00:00
feat: improve star coder to support multi lora layers (#2883)
* feat: improve star coder to support multi lora layers * feat: improve weight that support adapters and add tests for starcoder with lora * fix: bump snapshot for added tests * fix: rerun pre commit lints * fix: bump adapter test for added later names
This commit is contained in:
parent
5f78ec32a5
commit
82f6ea1b71
@ -0,0 +1,73 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 2284,
|
||||||
|
"logprob": -0.9355469,
|
||||||
|
"special": false,
|
||||||
|
"text": "():"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 303,
|
||||||
|
"logprob": -0.40795898,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1489,
|
||||||
|
"logprob": -0.27954102,
|
||||||
|
"special": false,
|
||||||
|
"text": " print"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 459,
|
||||||
|
"logprob": -0.6142578,
|
||||||
|
"special": false,
|
||||||
|
"text": "(\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8302,
|
||||||
|
"logprob": -0.68310547,
|
||||||
|
"special": false,
|
||||||
|
"text": "Hello"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10914,
|
||||||
|
"logprob": -1.4599609,
|
||||||
|
"special": false,
|
||||||
|
"text": " World"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 16013,
|
||||||
|
"logprob": -0.80126953,
|
||||||
|
"special": false,
|
||||||
|
"text": "!\")"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -0.625,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -0.23242188,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 610,
|
||||||
|
"logprob": -1.2294922,
|
||||||
|
"special": false,
|
||||||
|
"text": "def"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
|
||||||
|
}
|
@ -0,0 +1,373 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 60,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 40,
|
||||||
|
"logprob": -0.7944336,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 494,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 447,
|
||||||
|
"logprob": -0.1796875,
|
||||||
|
"special": false,
|
||||||
|
"text": " ["
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9009,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "markdown"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 98,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "]"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 37402,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " slideshow"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8492,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "={\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 7277,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "slide"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 100,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "_"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 700,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "type"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 582,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 332,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " \""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 7277,
|
||||||
|
"logprob": -0.06994629,
|
||||||
|
"special": false,
|
||||||
|
"text": "slide"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3667,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\"}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 40,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 607,
|
||||||
|
"logprob": -0.8261719,
|
||||||
|
"special": false,
|
||||||
|
"text": " #"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 244,
|
||||||
|
"logprob": -1.8574219,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 55,
|
||||||
|
"logprob": -1.4541016,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 51,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6208,
|
||||||
|
"logprob": -0.9794922,
|
||||||
|
"special": false,
|
||||||
|
"text": " What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 458,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 341,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10609,
|
||||||
|
"logprob": -0.69189453,
|
||||||
|
"special": false,
|
||||||
|
"text": " difference"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3761,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " between"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 331,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1168,
|
||||||
|
"logprob": -0.27172852,
|
||||||
|
"special": false,
|
||||||
|
"text": " list"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 480,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " and"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 331,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8871,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " tuple"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 68,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 40,
|
||||||
|
"logprob": -1.3359375,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 40,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 449,
|
||||||
|
"logprob": -0.03164673,
|
||||||
|
"special": false,
|
||||||
|
"text": " -"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 418,
|
||||||
|
"logprob": -1.0947266,
|
||||||
|
"special": false,
|
||||||
|
"text": " A"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1168,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " list"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 458,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 331,
|
||||||
|
"logprob": -0.3305664,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 14792,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " mutable"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6645,
|
||||||
|
"logprob": -0.40478516,
|
||||||
|
"special": false,
|
||||||
|
"text": " sequence"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 451,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4725,
|
||||||
|
"logprob": -0.50390625,
|
||||||
|
"special": false,
|
||||||
|
"text": " elements"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 49,
|
||||||
|
"logprob": -2.1269531,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2236,
|
||||||
|
"logprob": -0.1427002,
|
||||||
|
"special": false,
|
||||||
|
"text": " while"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 331,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8871,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " tuple"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 458,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 619,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " an"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26079,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " immutable"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6645,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " sequence"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 451,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4725,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " elements"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 51,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 40,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 449,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " -"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\n# + [markdown] slideshow={\"slide_type\": \"slide\"}\n# # 2. What is the difference between a list and a tuple?\n#\n# - A list is a mutable sequence of elements, while a tuple is an immutable sequence of elements.\n# -"
|
||||||
|
}
|
@ -0,0 +1,294 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -1.9091797,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -1.0478516,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 40,
|
||||||
|
"logprob": -3.015625,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 494,
|
||||||
|
"logprob": -1.4228516,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 447,
|
||||||
|
"logprob": -1.1025391,
|
||||||
|
"special": false,
|
||||||
|
"text": " ["
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9009,
|
||||||
|
"logprob": -0.0008444786,
|
||||||
|
"special": false,
|
||||||
|
"text": "markdown"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 98,
|
||||||
|
"logprob": -8.8095665e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": "]"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 37402,
|
||||||
|
"logprob": -0.5810547,
|
||||||
|
"special": false,
|
||||||
|
"text": " slideshow"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8492,
|
||||||
|
"logprob": -0.00022864342,
|
||||||
|
"special": false,
|
||||||
|
"text": "={\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 7277,
|
||||||
|
"logprob": -0.00030994415,
|
||||||
|
"special": false,
|
||||||
|
"text": "slide"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -1.9091797,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -1.0478516,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 40,
|
||||||
|
"logprob": -3.015625,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 494,
|
||||||
|
"logprob": -1.4228516,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 447,
|
||||||
|
"logprob": -1.1025391,
|
||||||
|
"special": false,
|
||||||
|
"text": " ["
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9009,
|
||||||
|
"logprob": -0.0008444786,
|
||||||
|
"special": false,
|
||||||
|
"text": "markdown"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 98,
|
||||||
|
"logprob": -8.8095665e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": "]"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 37402,
|
||||||
|
"logprob": -0.5810547,
|
||||||
|
"special": false,
|
||||||
|
"text": " slideshow"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8492,
|
||||||
|
"logprob": -0.00022864342,
|
||||||
|
"special": false,
|
||||||
|
"text": "={\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 7277,
|
||||||
|
"logprob": -0.00030994415,
|
||||||
|
"special": false,
|
||||||
|
"text": "slide"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -1.9091797,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -1.0478516,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 40,
|
||||||
|
"logprob": -3.015625,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 494,
|
||||||
|
"logprob": -1.4228516,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 447,
|
||||||
|
"logprob": -1.1025391,
|
||||||
|
"special": false,
|
||||||
|
"text": " ["
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9009,
|
||||||
|
"logprob": -0.0008444786,
|
||||||
|
"special": false,
|
||||||
|
"text": "markdown"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 98,
|
||||||
|
"logprob": -8.8095665e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": "]"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 37402,
|
||||||
|
"logprob": -0.5810547,
|
||||||
|
"special": false,
|
||||||
|
"text": " slideshow"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8492,
|
||||||
|
"logprob": -0.00022864342,
|
||||||
|
"special": false,
|
||||||
|
"text": "={\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 7277,
|
||||||
|
"logprob": -0.00030994415,
|
||||||
|
"special": false,
|
||||||
|
"text": "slide"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -1.9091797,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -1.0478516,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 40,
|
||||||
|
"logprob": -3.015625,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 494,
|
||||||
|
"logprob": -1.4228516,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 447,
|
||||||
|
"logprob": -1.1025391,
|
||||||
|
"special": false,
|
||||||
|
"text": " ["
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9009,
|
||||||
|
"logprob": -0.0008444786,
|
||||||
|
"special": false,
|
||||||
|
"text": "markdown"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 98,
|
||||||
|
"logprob": -8.8095665e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": "]"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 37402,
|
||||||
|
"logprob": -0.5810547,
|
||||||
|
"special": false,
|
||||||
|
"text": " slideshow"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8492,
|
||||||
|
"logprob": -0.00022864342,
|
||||||
|
"special": false,
|
||||||
|
"text": "={\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 7277,
|
||||||
|
"logprob": -0.00030994415,
|
||||||
|
"special": false,
|
||||||
|
"text": "slide"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||||
|
}
|
||||||
|
]
|
@ -0,0 +1,71 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 100,
|
||||||
|
"logprob": -0.9824219,
|
||||||
|
"special": false,
|
||||||
|
"text": "_"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5879,
|
||||||
|
"logprob": -0.3017578,
|
||||||
|
"special": false,
|
||||||
|
"text": "world"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2284,
|
||||||
|
"logprob": -0.68652344,
|
||||||
|
"special": false,
|
||||||
|
"text": "():"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 303,
|
||||||
|
"logprob": -0.27734375,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1489,
|
||||||
|
"logprob": -0.4482422,
|
||||||
|
"special": false,
|
||||||
|
"text": " print"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 459,
|
||||||
|
"logprob": -0.54248047,
|
||||||
|
"special": false,
|
||||||
|
"text": "(\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8302,
|
||||||
|
"logprob": -0.4296875,
|
||||||
|
"special": false,
|
||||||
|
"text": "Hello"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10914,
|
||||||
|
"logprob": -0.8544922,
|
||||||
|
"special": false,
|
||||||
|
"text": " World"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 16013,
|
||||||
|
"logprob": -0.7573242,
|
||||||
|
"special": false,
|
||||||
|
"text": "!\")"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -0.81347656,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": "_world():\n print(\"Hello World!\")\n"
|
||||||
|
}
|
79
integration-tests/models/test_flash_starcoder2_lora.py
Normal file
79
integration-tests/models/test_flash_starcoder2_lora.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_starcoder2_handle(launcher):
|
||||||
|
with launcher(
|
||||||
|
"bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"]
|
||||||
|
) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_starcoder2(flash_starcoder2_handle):
|
||||||
|
await flash_starcoder2_handle.health(300)
|
||||||
|
return flash_starcoder2_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
|
||||||
|
response = await flash_starcoder2.generate(
|
||||||
|
"def print_hello", max_new_tokens=10, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
|
||||||
|
response = await flash_starcoder2.generate(
|
||||||
|
"who are you?",
|
||||||
|
max_new_tokens=60,
|
||||||
|
temperature=0.2,
|
||||||
|
top_p=0.95,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 60
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_starcoder2_load(
|
||||||
|
flash_starcoder2, generate_load, response_snapshot
|
||||||
|
):
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_starcoder2, "who are you?", max_new_tokens=10, n=4
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_starcoder2_with_hugcode_adapter(
|
||||||
|
flash_starcoder2, response_snapshot
|
||||||
|
):
|
||||||
|
response = requests.post(
|
||||||
|
f"{flash_starcoder2.base_url}/generate",
|
||||||
|
headers=flash_starcoder2.headers,
|
||||||
|
json={
|
||||||
|
"inputs": "def print_hello",
|
||||||
|
"parameters": {
|
||||||
|
"max_new_tokens": 10,
|
||||||
|
"adapter_id": "smangrul/starcoder-3b-hugcoder",
|
||||||
|
"details": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["generated_text"] == '_world():\n print("Hello World!")\n'
|
||||||
|
|
||||||
|
assert data == response_snapshot
|
@ -94,6 +94,8 @@ def test_get_mlp_weights_with_gate_up_proj():
|
|||||||
|
|
||||||
# assert the result
|
# assert the result
|
||||||
expected = {
|
expected = {
|
||||||
|
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
|
||||||
|
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
|
||||||
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
|
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
|
||||||
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
|
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
|
||||||
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
||||||
@ -188,6 +190,8 @@ def test_get_mlp_weights_llama_compatibility():
|
|||||||
result = get_mlp_weights(3, mock_layer)
|
result = get_mlp_weights(3, mock_layer)
|
||||||
|
|
||||||
expected = {
|
expected = {
|
||||||
|
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
|
||||||
|
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
|
||||||
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
|
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
|
||||||
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
|
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
|
||||||
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
||||||
@ -240,6 +244,8 @@ def test_get_mlp_weights_gemma_compatibility():
|
|||||||
result = get_mlp_weights(3, mock_layer)
|
result = get_mlp_weights(3, mock_layer)
|
||||||
|
|
||||||
expected = {
|
expected = {
|
||||||
|
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
|
||||||
|
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
|
||||||
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
|
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
|
||||||
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
|
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
|
||||||
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
||||||
|
@ -6,9 +6,11 @@ from collections import defaultdict
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
import torch
|
import torch
|
||||||
from peft import LoraConfig as _LoraConfig
|
from peft import LoraConfig as _LoraConfig
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
|
|
||||||
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
||||||
|
|
||||||
@ -203,8 +205,17 @@ class LoraWeights(AdapterWeights):
|
|||||||
lora_a_list = [None] * nlayers
|
lora_a_list = [None] * nlayers
|
||||||
lora_b_list = [None] * nlayers
|
lora_b_list = [None] * nlayers
|
||||||
|
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
for layer_id in range(nlayers):
|
for layer_id in range(nlayers):
|
||||||
key = (layer_id, layer_type)
|
key = (layer_id, layer_type)
|
||||||
|
if key not in target_to_layer:
|
||||||
|
# There is no layer of this type in the model
|
||||||
|
log_master(
|
||||||
|
logger.warning,
|
||||||
|
f"Key specified in lora weights but not found in base model: {key}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
weight_name, layer = target_to_layer[key]
|
weight_name, layer = target_to_layer[key]
|
||||||
base_weight = layer.base_layer.linear.weight
|
base_weight = layer.base_layer.linear.weight
|
||||||
base_device = base_weight.device
|
base_device = base_weight.device
|
||||||
|
@ -1449,6 +1449,9 @@ def get_model_with_lora_adapters(
|
|||||||
"up_proj",
|
"up_proj",
|
||||||
"down_proj",
|
"down_proj",
|
||||||
"qkv_proj",
|
"qkv_proj",
|
||||||
|
# add c_* layers used in starcoder2
|
||||||
|
"c_proj",
|
||||||
|
"c_fc",
|
||||||
]
|
]
|
||||||
|
|
||||||
for layer_name in adapter_layers:
|
for layer_name in adapter_layers:
|
||||||
|
@ -32,6 +32,8 @@ from text_generation_server.layers.attention import (
|
|||||||
Seqlen,
|
Seqlen,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelMultiAdapterLinear,
|
||||||
|
TensorParallelAdapterRowLinear,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
@ -109,17 +111,31 @@ class Starcoder2Config(PretrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix, weights, layer_id):
|
||||||
|
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
|
||||||
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
|
sizes = [
|
||||||
|
head_size * config.num_attention_heads,
|
||||||
|
head_size * config.num_key_value_heads,
|
||||||
|
head_size * config.num_key_value_heads,
|
||||||
|
]
|
||||||
if config.num_attention_heads != config.num_key_value_heads:
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
return _load_gqa(config, prefix, weights)
|
base_layer = _load_gqa(config, prefix, weights)
|
||||||
else:
|
else:
|
||||||
return TensorParallelColumnLinear.load_multi(
|
base_layer = TensorParallelColumnLinear.load_multi(
|
||||||
config,
|
config,
|
||||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
prefixes=prefixes,
|
||||||
dim=0,
|
dim=0,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
|
return TensorParallelMultiAdapterLinear.load(
|
||||||
|
base_layer=base_layer,
|
||||||
|
layer_id=layer_id,
|
||||||
|
layer_names=prefixes,
|
||||||
|
sizes=sizes,
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _load_gqa(config, prefix: str, weights):
|
def _load_gqa(config, prefix: str, weights):
|
||||||
@ -157,6 +173,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||||||
class Starcoder2Attention(torch.nn.Module):
|
class Starcoder2Attention(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
index: int,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
config,
|
config,
|
||||||
weights,
|
weights,
|
||||||
@ -188,15 +205,23 @@ class Starcoder2Attention(torch.nn.Module):
|
|||||||
config.num_key_value_heads // weights.process_group.size()
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights, index)
|
||||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.o_proj",
|
prefix=f"{prefix}.o_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.use_bias,
|
bias=getattr(config, "use_bias", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
o_proj,
|
||||||
|
index,
|
||||||
|
"o_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
self.kv_head_mapping = torch.arange(
|
self.kv_head_mapping = torch.arange(
|
||||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
@ -214,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module):
|
|||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
[
|
[
|
||||||
self.head_size * self.num_heads,
|
self.head_size * self.num_heads,
|
||||||
@ -267,11 +293,13 @@ class Starcoder2Attention(torch.nn.Module):
|
|||||||
kv_scales=self.kv_scales,
|
kv_scales=self.kv_scales,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(
|
||||||
|
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Starcoder2MLP(nn.Module):
|
class Starcoder2MLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights, index):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act = config.hidden_act
|
act = config.hidden_act
|
||||||
self.act = (
|
self.act = (
|
||||||
@ -285,27 +313,42 @@ class Starcoder2MLP(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Fuse gate and up proj
|
# Fuse gate and up proj
|
||||||
self.c_fc = TensorParallelColumnLinear.load(
|
c_fc = TensorParallelColumnLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.c_fc",
|
prefix=f"{prefix}.c_fc",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
self.c_proj = TensorParallelRowLinear.load(
|
c_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.c_proj",
|
prefix=f"{prefix}.c_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
self.c_fc = TensorParallelMultiAdapterLinear.load(
|
||||||
hidden_states = self.c_fc(hidden_states)
|
c_fc,
|
||||||
|
layer_id=index,
|
||||||
|
layer_names=[f"{prefix}.c_fc"],
|
||||||
|
sizes=[config.intermediate_size, config.intermediate_size],
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.c_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
c_proj,
|
||||||
|
index,
|
||||||
|
"c_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, adapter_data):
|
||||||
|
hidden_states = self.c_fc(hidden_states, adapter_data)
|
||||||
hidden_states = self.act(hidden_states)
|
hidden_states = self.act(hidden_states)
|
||||||
return self.c_proj(hidden_states)
|
return self.c_proj(hidden_states, adapter_data)
|
||||||
|
|
||||||
|
|
||||||
class Starcoder2GatedMLP(nn.Module):
|
class Starcoder2GatedMLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, index, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act = config.hidden_act
|
act = config.hidden_act
|
||||||
self.act = (
|
self.act = (
|
||||||
@ -319,27 +362,47 @@ class Starcoder2GatedMLP(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Fuse gate and up proj
|
# Fuse gate and up proj
|
||||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
|
||||||
|
sizes = [
|
||||||
|
config.intermediate_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
]
|
||||||
|
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
config,
|
config,
|
||||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
prefixes=prefixes,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
dim=0,
|
dim=0,
|
||||||
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
self.down_proj = TensorParallelRowLinear.load(
|
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||||
|
gate_up_proj,
|
||||||
|
index,
|
||||||
|
layer_names=prefixes,
|
||||||
|
sizes=sizes,
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
down_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.down_proj",
|
prefix=f"{prefix}.down_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
|
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
down_proj,
|
||||||
|
index,
|
||||||
|
"down_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
self.intermediate_size = (
|
self.intermediate_size = (
|
||||||
config.intermediate_size // weights.process_group.size()
|
config.intermediate_size // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states, adapter_data):
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
STARCODER2_NORMALIZATION_CLASSES = {
|
STARCODER2_NORMALIZATION_CLASSES = {
|
||||||
@ -358,11 +421,11 @@ class Starcoder2Layer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"model.layers.{layer_id}"
|
prefix = f"model.layers.{layer_id}"
|
||||||
self.self_attn = Starcoder2Attention(
|
self.self_attn = Starcoder2Attention(
|
||||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
|
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
|
||||||
prefix=f"{prefix}.mlp", config=config, weights=weights
|
prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
|
||||||
)
|
)
|
||||||
|
|
||||||
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
||||||
@ -389,6 +452,7 @@ class Starcoder2Layer(nn.Module):
|
|||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
@ -404,6 +468,7 @@ class Starcoder2Layer(nn.Module):
|
|||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
@ -411,7 +476,7 @@ class Starcoder2Layer(nn.Module):
|
|||||||
attn_output, res
|
attn_output, res
|
||||||
)
|
)
|
||||||
|
|
||||||
mlp_output = self.mlp(normed_attn_res_output)
|
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||||
|
|
||||||
return mlp_output, attn_res
|
return mlp_output, attn_res
|
||||||
|
|
||||||
@ -458,6 +523,7 @@ class Starcoder2Model(torch.nn.Module):
|
|||||||
max_s: int,
|
max_s: int,
|
||||||
true_max_s: int,
|
true_max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
adapter_data,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
@ -481,6 +547,7 @@ class Starcoder2Model(torch.nn.Module):
|
|||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -552,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
|||||||
max_s,
|
max_s,
|
||||||
true_max_s,
|
true_max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -281,6 +281,12 @@ def get_mlp_weights(i, layer):
|
|||||||
if hasattr(mlp, "up_proj"):
|
if hasattr(mlp, "up_proj"):
|
||||||
weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
|
weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
|
||||||
|
|
||||||
|
if hasattr(mlp, "c_fc"):
|
||||||
|
weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc)
|
||||||
|
|
||||||
|
if hasattr(mlp, "c_proj"):
|
||||||
|
weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj)
|
||||||
|
|
||||||
if hasattr(mlp, "down_proj"):
|
if hasattr(mlp, "down_proj"):
|
||||||
weights[(i, "down_proj")] = (
|
weights[(i, "down_proj")] = (
|
||||||
f"model.layers.{i}.mlp.down_proj",
|
f"model.layers.{i}.mlp.down_proj",
|
||||||
|
Loading…
Reference in New Issue
Block a user