mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve weight that support adapters and add tests for starcoder with lora
This commit is contained in:
parent
31778a6508
commit
d611f0f5e2
@ -0,0 +1,73 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 2284,
|
||||
"logprob": -0.9355469,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": -0.40795898,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -0.27954102,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": -0.6142578,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": -0.68310547,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 10914,
|
||||
"logprob": -1.4599609,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": -0.80126953,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.625,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.23242188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": -1.2294922,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
|
||||
}
|
@ -0,0 +1,373 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 60,
|
||||
"prefill": [],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 78,
|
||||
"logprob": -1.0654297,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 3874,
|
||||
"logprob": -0.4074707,
|
||||
"special": false,
|
||||
"text": " am"
|
||||
},
|
||||
{
|
||||
"id": 331,
|
||||
"logprob": -0.12695312,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 2951,
|
||||
"logprob": -0.4501953,
|
||||
"special": false,
|
||||
"text": " software"
|
||||
},
|
||||
{
|
||||
"id": 46380,
|
||||
"logprob": -0.15124512,
|
||||
"special": false,
|
||||
"text": " engineer"
|
||||
},
|
||||
{
|
||||
"id": 51,
|
||||
"logprob": -0.953125,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.66259766,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 8204,
|
||||
"logprob": -0.95947266,
|
||||
"special": false,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 458,
|
||||
"logprob": -0.3869629,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 1390,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " your"
|
||||
},
|
||||
{
|
||||
"id": 27455,
|
||||
"logprob": -0.07891846,
|
||||
"special": false,
|
||||
"text": " favorite"
|
||||
},
|
||||
{
|
||||
"id": 16100,
|
||||
"logprob": -0.4074707,
|
||||
"special": false,
|
||||
"text": " programming"
|
||||
},
|
||||
{
|
||||
"id": 2940,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " language"
|
||||
},
|
||||
{
|
||||
"id": 68,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 78,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 2144,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " like"
|
||||
},
|
||||
{
|
||||
"id": 5006,
|
||||
"logprob": -0.10021973,
|
||||
"special": false,
|
||||
"text": " Python"
|
||||
},
|
||||
{
|
||||
"id": 51,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 8204,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 458,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 1390,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " your"
|
||||
},
|
||||
{
|
||||
"id": 27455,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " favorite"
|
||||
},
|
||||
{
|
||||
"id": 16100,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " programming"
|
||||
},
|
||||
{
|
||||
"id": 2940,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " language"
|
||||
},
|
||||
{
|
||||
"id": 68,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 78,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 2144,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " like"
|
||||
},
|
||||
{
|
||||
"id": 5006,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " Python"
|
||||
},
|
||||
{
|
||||
"id": 51,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 8204,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 458,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 1390,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " your"
|
||||
},
|
||||
{
|
||||
"id": 27455,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " favorite"
|
||||
},
|
||||
{
|
||||
"id": 16100,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " programming"
|
||||
},
|
||||
{
|
||||
"id": 2940,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " language"
|
||||
},
|
||||
{
|
||||
"id": 68,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 78,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 2144,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " like"
|
||||
},
|
||||
{
|
||||
"id": 5006,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " Python"
|
||||
},
|
||||
{
|
||||
"id": 51,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 8204,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 458,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 1390,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " your"
|
||||
},
|
||||
{
|
||||
"id": 27455,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " favorite"
|
||||
},
|
||||
{
|
||||
"id": 16100,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " programming"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nI am a software engineer.\n\nWhat is your favorite programming language?\n\nI like Python.\n\nWhat is your favorite programming language?\n\nI like Python.\n\nWhat is your favorite programming language?\n\nI like Python.\n\nWhat is your favorite programming"
|
||||
}
|
@ -0,0 +1,294 @@
|
||||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.9091797,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.0478516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": -3.015625,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": -1.4228516,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 447,
|
||||
"logprob": -1.1025391,
|
||||
"special": false,
|
||||
"text": " ["
|
||||
},
|
||||
{
|
||||
"id": 9009,
|
||||
"logprob": -0.0008444786,
|
||||
"special": false,
|
||||
"text": "markdown"
|
||||
},
|
||||
{
|
||||
"id": 98,
|
||||
"logprob": -8.8095665e-05,
|
||||
"special": false,
|
||||
"text": "]"
|
||||
},
|
||||
{
|
||||
"id": 37402,
|
||||
"logprob": -0.5810547,
|
||||
"special": false,
|
||||
"text": " slideshow"
|
||||
},
|
||||
{
|
||||
"id": 8492,
|
||||
"logprob": -0.00022864342,
|
||||
"special": false,
|
||||
"text": "={\""
|
||||
},
|
||||
{
|
||||
"id": 7277,
|
||||
"logprob": -0.00030994415,
|
||||
"special": false,
|
||||
"text": "slide"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.9091797,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.0478516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": -3.015625,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": -1.4228516,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 447,
|
||||
"logprob": -1.1025391,
|
||||
"special": false,
|
||||
"text": " ["
|
||||
},
|
||||
{
|
||||
"id": 9009,
|
||||
"logprob": -0.0008444786,
|
||||
"special": false,
|
||||
"text": "markdown"
|
||||
},
|
||||
{
|
||||
"id": 98,
|
||||
"logprob": -8.8095665e-05,
|
||||
"special": false,
|
||||
"text": "]"
|
||||
},
|
||||
{
|
||||
"id": 37402,
|
||||
"logprob": -0.5810547,
|
||||
"special": false,
|
||||
"text": " slideshow"
|
||||
},
|
||||
{
|
||||
"id": 8492,
|
||||
"logprob": -0.00022864342,
|
||||
"special": false,
|
||||
"text": "={\""
|
||||
},
|
||||
{
|
||||
"id": 7277,
|
||||
"logprob": -0.00030994415,
|
||||
"special": false,
|
||||
"text": "slide"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.9091797,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.0478516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": -3.015625,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": -1.4228516,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 447,
|
||||
"logprob": -1.1025391,
|
||||
"special": false,
|
||||
"text": " ["
|
||||
},
|
||||
{
|
||||
"id": 9009,
|
||||
"logprob": -0.0008444786,
|
||||
"special": false,
|
||||
"text": "markdown"
|
||||
},
|
||||
{
|
||||
"id": 98,
|
||||
"logprob": -8.8095665e-05,
|
||||
"special": false,
|
||||
"text": "]"
|
||||
},
|
||||
{
|
||||
"id": 37402,
|
||||
"logprob": -0.5810547,
|
||||
"special": false,
|
||||
"text": " slideshow"
|
||||
},
|
||||
{
|
||||
"id": 8492,
|
||||
"logprob": -0.00022864342,
|
||||
"special": false,
|
||||
"text": "={\""
|
||||
},
|
||||
{
|
||||
"id": 7277,
|
||||
"logprob": -0.00030994415,
|
||||
"special": false,
|
||||
"text": "slide"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.9091797,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -1.0478516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"logprob": -3.015625,
|
||||
"special": false,
|
||||
"text": "#"
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": -1.4228516,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 447,
|
||||
"logprob": -1.1025391,
|
||||
"special": false,
|
||||
"text": " ["
|
||||
},
|
||||
{
|
||||
"id": 9009,
|
||||
"logprob": -0.0008444786,
|
||||
"special": false,
|
||||
"text": "markdown"
|
||||
},
|
||||
{
|
||||
"id": 98,
|
||||
"logprob": -8.8095665e-05,
|
||||
"special": false,
|
||||
"text": "]"
|
||||
},
|
||||
{
|
||||
"id": 37402,
|
||||
"logprob": -0.5810547,
|
||||
"special": false,
|
||||
"text": " slideshow"
|
||||
},
|
||||
{
|
||||
"id": 8492,
|
||||
"logprob": -0.00022864342,
|
||||
"special": false,
|
||||
"text": "={\""
|
||||
},
|
||||
{
|
||||
"id": 7277,
|
||||
"logprob": -0.00030994415,
|
||||
"special": false,
|
||||
"text": "slide"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
|
||||
}
|
||||
]
|
@ -0,0 +1,71 @@
|
||||
{
|
||||
"details": {
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": -0.9824219,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5879,
|
||||
"logprob": -0.3017578,
|
||||
"special": false,
|
||||
"text": "world"
|
||||
},
|
||||
{
|
||||
"id": 2284,
|
||||
"logprob": -0.68652344,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": -0.27734375,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -0.4482422,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": -0.54248047,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": -0.4296875,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 10914,
|
||||
"logprob": -0.8544922,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": -0.7573242,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.81347656,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "_world():\n print(\"Hello World!\")\n"
|
||||
}
|
78
integration-tests/models/test_flash_starcoder2_lora.py
Normal file
78
integration-tests/models/test_flash_starcoder2_lora.py
Normal file
@ -0,0 +1,78 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_starcoder2_handle(launcher):
|
||||
with launcher(
|
||||
"bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"]
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_starcoder2(flash_starcoder2_handle):
|
||||
await flash_starcoder2_handle.health(300)
|
||||
return flash_starcoder2_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
|
||||
response = await flash_starcoder2.generate(
|
||||
"def print_hello", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
|
||||
response = await flash_starcoder2.generate(
|
||||
"who are you?",
|
||||
max_new_tokens=60,
|
||||
temperature=0.2,
|
||||
top_p=0.95,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 60
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder2_load(
|
||||
flash_starcoder2, generate_load, response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_starcoder2, "who are you?", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder2_with_hugcode_adapter(
|
||||
flash_starcoder2, response_snapshot
|
||||
):
|
||||
response = requests.post(
|
||||
f"{flash_starcoder2.base_url}/generate",
|
||||
headers=flash_starcoder2.headers,
|
||||
json={
|
||||
"inputs": "def print_hello",
|
||||
"parameters": {
|
||||
"max_new_tokens": 10,
|
||||
"adapter_id": "smangrul/starcoder-3b-hugcoder",
|
||||
"details": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["generated_text"] == "_world():\n print(\"Hello World!\")\n"
|
||||
|
||||
assert data == response_snapshot
|
@ -6,9 +6,11 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from loguru import logger
|
||||
import torch
|
||||
from peft import LoraConfig as _LoraConfig
|
||||
from torch.distributed import ProcessGroup
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
||||
|
||||
@ -203,8 +205,17 @@ class LoraWeights(AdapterWeights):
|
||||
lora_a_list = [None] * nlayers
|
||||
lora_b_list = [None] * nlayers
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
for layer_id in range(nlayers):
|
||||
key = (layer_id, layer_type)
|
||||
if key not in target_to_layer:
|
||||
# There is no layer of this type in the model
|
||||
log_master(
|
||||
logger.warning,
|
||||
f"Key specified in lora weights but not found in base model: {key}",
|
||||
)
|
||||
return None
|
||||
|
||||
weight_name, layer = target_to_layer[key]
|
||||
base_weight = layer.base_layer.linear.weight
|
||||
base_device = base_weight.device
|
||||
|
@ -1449,6 +1449,9 @@ def get_model_with_lora_adapters(
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
"qkv_proj",
|
||||
# add c_* layers used in starcoder2
|
||||
"c_proj",
|
||||
"c_fc",
|
||||
]
|
||||
|
||||
for layer_name in adapter_layers:
|
||||
|
@ -112,16 +112,16 @@ class Starcoder2Config(PretrainedConfig):
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights, layer_id):
|
||||
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
sizes = [
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
]
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
base_layer = _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
sizes = [
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
]
|
||||
base_layer = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=prefixes,
|
||||
@ -239,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
@ -292,11 +293,13 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
return self.o_proj(
|
||||
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||
)
|
||||
|
||||
|
||||
class Starcoder2MLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, prefix, config, weights, index):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
@ -310,23 +313,38 @@ class Starcoder2MLP(nn.Module):
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.c_fc = TensorParallelColumnLinear.load(
|
||||
c_fc = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.c_fc",
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
self.c_proj = TensorParallelRowLinear.load(
|
||||
c_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
self.c_fc = TensorParallelMultiAdapterLinear.load(
|
||||
c_fc,
|
||||
layer_id=index,
|
||||
layer_names=[f"{prefix}.c_fc"],
|
||||
sizes=[config.intermediate_size, config.intermediate_size],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.c_proj = TensorParallelAdapterRowLinear.load(
|
||||
c_proj,
|
||||
index,
|
||||
"c_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
hidden_states = self.c_fc(hidden_states, adapter_data)
|
||||
hidden_states = self.act(hidden_states)
|
||||
return self.c_proj(hidden_states)
|
||||
return self.c_proj(hidden_states, adapter_data)
|
||||
|
||||
|
||||
class Starcoder2GatedMLP(nn.Module):
|
||||
@ -379,10 +397,12 @@ class Starcoder2GatedMLP(nn.Module):
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
return self.down_proj(
|
||||
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||
)
|
||||
|
||||
|
||||
STARCODER2_NORMALIZATION_CLASSES = {
|
||||
@ -405,7 +425,7 @@ class Starcoder2Layer(nn.Module):
|
||||
)
|
||||
|
||||
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
|
||||
)
|
||||
|
||||
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
||||
@ -432,6 +452,7 @@ class Starcoder2Layer(nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
@ -447,6 +468,7 @@ class Starcoder2Layer(nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
@ -454,7 +476,7 @@ class Starcoder2Layer(nn.Module):
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
@ -501,6 +523,7 @@ class Starcoder2Model(torch.nn.Module):
|
||||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
adapter_data,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
@ -524,6 +547,7 @@ class Starcoder2Model(torch.nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@ -595,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
||||
max_s,
|
||||
true_max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
@ -281,6 +281,12 @@ def get_mlp_weights(i, layer):
|
||||
if hasattr(mlp, "up_proj"):
|
||||
weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
|
||||
|
||||
if hasattr(mlp, "c_fc"):
|
||||
weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc)
|
||||
|
||||
if hasattr(mlp, "c_proj"):
|
||||
weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj)
|
||||
|
||||
if hasattr(mlp, "down_proj"):
|
||||
weights[(i, "down_proj")] = (
|
||||
f"model.layers.{i}.mlp.down_proj",
|
||||
|
Loading…
Reference in New Issue
Block a user