mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
feat: improve star coder to support multi lora layers (#2883)
* feat: improve star coder to support multi lora layers * feat: improve weight that support adapters and add tests for starcoder with lora * fix: bump snapshot for added tests * fix: rerun pre commit lints * fix: bump adapter test for added later names
This commit is contained in:
parent
5f78ec32a5
commit
82f6ea1b71
@ -0,0 +1,73 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 2284,
|
||||
"logprob": -0.9355469,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": -0.40795898,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -0.27954102,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": -0.6142578,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": -0.68310547,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 10914,
|
||||
"logprob": -1.4599609,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": -0.80126953,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.625,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.23242188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": -1.2294922,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
|
||||
}
|
@ -0,0 +1,373 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 60,
|
||||
"prefill": [],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": -0.7944336,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 447,
|
||||
"logprob": -0.1796875,
|
||||
"special": false,
|
||||
"text": " ["
|
||||
},
|
||||
{
|
||||
"id": 9009,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "markdown"
|
||||
},
|
||||
{
|
||||
"id": 98,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "]"
|
||||
},
|
||||
{
|
||||
"id": 37402,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " slideshow"
|
||||
},
|
||||
{
|
||||
"id": 8492,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "={\""
|
||||
},
|
||||
{
|
||||
"id": 7277,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "slide"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 700,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "type"
|
||||
},
|
||||
{
|
||||
"id": 582,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\":"
|
||||
},
|
||||
{
|
||||
"id": 332,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 7277,
|
||||
"logprob": -0.06994629,
|
||||
"special": false,
|
||||
"text": "slide"
|
||||
},
|
||||
{
|
||||
"id": 3667,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\"}"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 607,
|
||||
"logprob": -0.8261719,
|
||||
"special": false,
|
||||
"text": " #"
|
||||
},
|
||||
{
|
||||
"id": 244,
|
||||
"logprob": -1.8574219,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 55,
|
||||
"logprob": -1.4541016,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 51,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 6208,
|
||||
"logprob": -0.9794922,
|
||||
"special": false,
|
||||
"text": " What"
|
||||
},
|
||||
{
|
||||
"id": 458,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 341,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 10609,
|
||||
"logprob": -0.69189453,
|
||||
"special": false,
|
||||
"text": " difference"
|
||||
},
|
||||
{
|
||||
"id": 3761,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " between"
|
||||
},
|
||||
{
|
||||
"id": 331,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1168,
|
||||
"logprob": -0.27172852,
|
||||
"special": false,
|
||||
"text": " list"
|
||||
},
|
||||
{
|
||||
"id": 480,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 331,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 8871,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " tuple"
|
||||
},
|
||||
{
|
||||
"id": 68,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": -1.3359375,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 449,
|
||||
"logprob": -0.03164673,
|
||||
"special": false,
|
||||
"text": " -"
|
||||
},
|
||||
{
|
||||
"id": 418,
|
||||
"logprob": -1.0947266,
|
||||
"special": false,
|
||||
"text": " A"
|
||||
},
|
||||
{
|
||||
"id": 1168,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " list"
|
||||
},
|
||||
{
|
||||
"id": 458,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 331,
|
||||
"logprob": -0.3305664,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 14792,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " mutable"
|
||||
},
|
||||
{
|
||||
"id": 6645,
|
||||
"logprob": -0.40478516,
|
||||
"special": false,
|
||||
"text": " sequence"
|
||||
},
|
||||
{
|
||||
"id": 451,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 4725,
|
||||
"logprob": -0.50390625,
|
||||
"special": false,
|
||||
"text": " elements"
|
||||
},
|
||||
{
|
||||
"id": 49,
|
||||
"logprob": -2.1269531,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 2236,
|
||||
"logprob": -0.1427002,
|
||||
"special": false,
|
||||
"text": " while"
|
||||
},
|
||||
{
|
||||
"id": 331,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 8871,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " tuple"
|
||||
},
|
||||
{
|
||||
"id": 458,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 619,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " an"
|
||||
},
|
||||
{
|
||||
"id": 26079,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " immutable"
|
||||
},
|
||||
{
|
||||
"id": 6645,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " sequence"
|
||||
},
|
||||
{
|
||||
"id": 451,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 4725,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " elements"
|
||||
},
|
||||
{
|
||||
"id": 51,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 449,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " -"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# + [markdown] slideshow={\"slide_type\": \"slide\"}\n# # 2. What is the difference between a list and a tuple?\n#\n# - A list is a mutable sequence of elements, while a tuple is an immutable sequence of elements.\n# -"
|
||||
}
|
@ -0,0 +1,294 @@
|
||||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.9091797,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.0478516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": -3.015625,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": -1.4228516,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 447,
|
||||
"logprob": -1.1025391,
|
||||
"special": false,
|
||||
"text": " ["
|
||||
},
|
||||
{
|
||||
"id": 9009,
|
||||
"logprob": -0.0008444786,
|
||||
"special": false,
|
||||
"text": "markdown"
|
||||
},
|
||||
{
|
||||
"id": 98,
|
||||
"logprob": -8.8095665e-05,
|
||||
"special": false,
|
||||
"text": "]"
|
||||
},
|
||||
{
|
||||
"id": 37402,
|
||||
"logprob": -0.5810547,
|
||||
"special": false,
|
||||
"text": " slideshow"
|
||||
},
|
||||
{
|
||||
"id": 8492,
|
||||
"logprob": -0.00022864342,
|
||||
"special": false,
|
||||
"text": "={\""
|
||||
},
|
||||
{
|
||||
"id": 7277,
|
||||
"logprob": -0.00030994415,
|
||||
"special": false,
|
||||
"text": "slide"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.9091797,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.0478516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": -3.015625,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": -1.4228516,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 447,
|
||||
"logprob": -1.1025391,
|
||||
"special": false,
|
||||
"text": " ["
|
||||
},
|
||||
{
|
||||
"id": 9009,
|
||||
"logprob": -0.0008444786,
|
||||
"special": false,
|
||||
"text": "markdown"
|
||||
},
|
||||
{
|
||||
"id": 98,
|
||||
"logprob": -8.8095665e-05,
|
||||
"special": false,
|
||||
"text": "]"
|
||||
},
|
||||
{
|
||||
"id": 37402,
|
||||
"logprob": -0.5810547,
|
||||
"special": false,
|
||||
"text": " slideshow"
|
||||
},
|
||||
{
|
||||
"id": 8492,
|
||||
"logprob": -0.00022864342,
|
||||
"special": false,
|
||||
"text": "={\""
|
||||
},
|
||||
{
|
||||
"id": 7277,
|
||||
"logprob": -0.00030994415,
|
||||
"special": false,
|
||||
"text": "slide"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.9091797,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.0478516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": -3.015625,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": -1.4228516,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 447,
|
||||
"logprob": -1.1025391,
|
||||
"special": false,
|
||||
"text": " ["
|
||||
},
|
||||
{
|
||||
"id": 9009,
|
||||
"logprob": -0.0008444786,
|
||||
"special": false,
|
||||
"text": "markdown"
|
||||
},
|
||||
{
|
||||
"id": 98,
|
||||
"logprob": -8.8095665e-05,
|
||||
"special": false,
|
||||
"text": "]"
|
||||
},
|
||||
{
|
||||
"id": 37402,
|
||||
"logprob": -0.5810547,
|
||||
"special": false,
|
||||
"text": " slideshow"
|
||||
},
|
||||
{
|
||||
"id": 8492,
|
||||
"logprob": -0.00022864342,
|
||||
"special": false,
|
||||
"text": "={\""
|
||||
},
|
||||
{
|
||||
"id": 7277,
|
||||
"logprob": -0.00030994415,
|
||||
"special": false,
|
||||
"text": "slide"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.9091797,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.0478516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": -3.015625,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": -1.4228516,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 447,
|
||||
"logprob": -1.1025391,
|
||||
"special": false,
|
||||
"text": " ["
|
||||
},
|
||||
{
|
||||
"id": 9009,
|
||||
"logprob": -0.0008444786,
|
||||
"special": false,
|
||||
"text": "markdown"
|
||||
},
|
||||
{
|
||||
"id": 98,
|
||||
"logprob": -8.8095665e-05,
|
||||
"special": false,
|
||||
"text": "]"
|
||||
},
|
||||
{
|
||||
"id": 37402,
|
||||
"logprob": -0.5810547,
|
||||
"special": false,
|
||||
"text": " slideshow"
|
||||
},
|
||||
{
|
||||
"id": 8492,
|
||||
"logprob": -0.00022864342,
|
||||
"special": false,
|
||||
"text": "={\""
|
||||
},
|
||||
{
|
||||
"id": 7277,
|
||||
"logprob": -0.00030994415,
|
||||
"special": false,
|
||||
"text": "slide"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||
}
|
||||
]
|
@ -0,0 +1,71 @@
|
||||
{
|
||||
"details": {
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": -0.9824219,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5879,
|
||||
"logprob": -0.3017578,
|
||||
"special": false,
|
||||
"text": "world"
|
||||
},
|
||||
{
|
||||
"id": 2284,
|
||||
"logprob": -0.68652344,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": -0.27734375,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -0.4482422,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": -0.54248047,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": -0.4296875,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 10914,
|
||||
"logprob": -0.8544922,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": -0.7573242,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.81347656,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "_world():\n print(\"Hello World!\")\n"
|
||||
}
|
79
integration-tests/models/test_flash_starcoder2_lora.py
Normal file
79
integration-tests/models/test_flash_starcoder2_lora.py
Normal file
@ -0,0 +1,79 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_starcoder2_handle(launcher):
|
||||
with launcher(
|
||||
"bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"]
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_starcoder2(flash_starcoder2_handle):
|
||||
await flash_starcoder2_handle.health(300)
|
||||
return flash_starcoder2_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
|
||||
response = await flash_starcoder2.generate(
|
||||
"def print_hello", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
|
||||
response = await flash_starcoder2.generate(
|
||||
"who are you?",
|
||||
max_new_tokens=60,
|
||||
temperature=0.2,
|
||||
top_p=0.95,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 60
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder2_load(
|
||||
flash_starcoder2, generate_load, response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_starcoder2, "who are you?", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder2_with_hugcode_adapter(
|
||||
flash_starcoder2, response_snapshot
|
||||
):
|
||||
response = requests.post(
|
||||
f"{flash_starcoder2.base_url}/generate",
|
||||
headers=flash_starcoder2.headers,
|
||||
json={
|
||||
"inputs": "def print_hello",
|
||||
"parameters": {
|
||||
"max_new_tokens": 10,
|
||||
"adapter_id": "smangrul/starcoder-3b-hugcoder",
|
||||
"details": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["generated_text"] == '_world():\n print("Hello World!")\n'
|
||||
|
||||
assert data == response_snapshot
|
@ -94,6 +94,8 @@ def test_get_mlp_weights_with_gate_up_proj():
|
||||
|
||||
# assert the result
|
||||
expected = {
|
||||
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
|
||||
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
|
||||
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
|
||||
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
|
||||
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
||||
@ -188,6 +190,8 @@ def test_get_mlp_weights_llama_compatibility():
|
||||
result = get_mlp_weights(3, mock_layer)
|
||||
|
||||
expected = {
|
||||
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
|
||||
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
|
||||
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
|
||||
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
|
||||
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
||||
@ -240,6 +244,8 @@ def test_get_mlp_weights_gemma_compatibility():
|
||||
result = get_mlp_weights(3, mock_layer)
|
||||
|
||||
expected = {
|
||||
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
|
||||
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
|
||||
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
|
||||
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
|
||||
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
|
||||
|
@ -6,9 +6,11 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from loguru import logger
|
||||
import torch
|
||||
from peft import LoraConfig as _LoraConfig
|
||||
from torch.distributed import ProcessGroup
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
||||
|
||||
@ -203,8 +205,17 @@ class LoraWeights(AdapterWeights):
|
||||
lora_a_list = [None] * nlayers
|
||||
lora_b_list = [None] * nlayers
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
for layer_id in range(nlayers):
|
||||
key = (layer_id, layer_type)
|
||||
if key not in target_to_layer:
|
||||
# There is no layer of this type in the model
|
||||
log_master(
|
||||
logger.warning,
|
||||
f"Key specified in lora weights but not found in base model: {key}",
|
||||
)
|
||||
return None
|
||||
|
||||
weight_name, layer = target_to_layer[key]
|
||||
base_weight = layer.base_layer.linear.weight
|
||||
base_device = base_weight.device
|
||||
|
@ -1449,6 +1449,9 @@ def get_model_with_lora_adapters(
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
"qkv_proj",
|
||||
# add c_* layers used in starcoder2
|
||||
"c_proj",
|
||||
"c_fc",
|
||||
]
|
||||
|
||||
for layer_name in adapter_layers:
|
||||
|
@ -32,6 +32,8 @@ from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
@ -109,17 +111,31 @@ class Starcoder2Config(PretrainedConfig):
|
||||
)
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
def load_attention(config, prefix, weights, layer_id):
|
||||
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
sizes = [
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
]
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
base_layer = _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
base_layer = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
prefixes=prefixes,
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
return TensorParallelMultiAdapterLinear.load(
|
||||
base_layer=base_layer,
|
||||
layer_id=layer_id,
|
||||
layer_names=prefixes,
|
||||
sizes=sizes,
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
@ -157,6 +173,7 @@ def _load_gqa(config, prefix: str, weights):
|
||||
class Starcoder2Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
index: int,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
@ -188,15 +205,23 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
self.query_key_value = load_attention(config, prefix, weights, index)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
bias=getattr(config, "use_bias", False),
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
o_proj,
|
||||
index,
|
||||
"o_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
@ -214,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
@ -267,11 +293,13 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
return self.o_proj(
|
||||
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||
)
|
||||
|
||||
|
||||
class Starcoder2MLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, prefix, config, weights, index):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
@ -285,27 +313,42 @@ class Starcoder2MLP(nn.Module):
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.c_fc = TensorParallelColumnLinear.load(
|
||||
c_fc = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.c_fc",
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
self.c_proj = TensorParallelRowLinear.load(
|
||||
c_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
self.c_fc = TensorParallelMultiAdapterLinear.load(
|
||||
c_fc,
|
||||
layer_id=index,
|
||||
layer_names=[f"{prefix}.c_fc"],
|
||||
sizes=[config.intermediate_size, config.intermediate_size],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.c_proj = TensorParallelAdapterRowLinear.load(
|
||||
c_proj,
|
||||
index,
|
||||
"c_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
hidden_states = self.c_fc(hidden_states, adapter_data)
|
||||
hidden_states = self.act(hidden_states)
|
||||
return self.c_proj(hidden_states)
|
||||
return self.c_proj(hidden_states, adapter_data)
|
||||
|
||||
|
||||
class Starcoder2GatedMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, index, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
@ -319,27 +362,47 @@ class Starcoder2GatedMLP(nn.Module):
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
|
||||
sizes = [
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
]
|
||||
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
prefixes=prefixes,
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
index,
|
||||
layer_names=prefixes,
|
||||
sizes=sizes,
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
index,
|
||||
"down_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
return self.down_proj(
|
||||
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||
)
|
||||
|
||||
|
||||
STARCODER2_NORMALIZATION_CLASSES = {
|
||||
@ -358,11 +421,11 @@ class Starcoder2Layer(nn.Module):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = Starcoder2Attention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
|
||||
)
|
||||
|
||||
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
|
||||
)
|
||||
|
||||
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
||||
@ -389,6 +452,7 @@ class Starcoder2Layer(nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
@ -404,6 +468,7 @@ class Starcoder2Layer(nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
@ -411,7 +476,7 @@ class Starcoder2Layer(nn.Module):
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
@ -458,6 +523,7 @@ class Starcoder2Model(torch.nn.Module):
|
||||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
adapter_data,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
@ -481,6 +547,7 @@ class Starcoder2Model(torch.nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@ -552,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
||||
max_s,
|
||||
true_max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
@ -281,6 +281,12 @@ def get_mlp_weights(i, layer):
|
||||
if hasattr(mlp, "up_proj"):
|
||||
weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
|
||||
|
||||
if hasattr(mlp, "c_fc"):
|
||||
weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc)
|
||||
|
||||
if hasattr(mlp, "c_proj"):
|
||||
weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj)
|
||||
|
||||
if hasattr(mlp, "down_proj"):
|
||||
weights[(i, "down_proj")] = (
|
||||
f"model.layers.{i}.mlp.down_proj",
|
||||
|
Loading…
Reference in New Issue
Block a user