This commit is contained in:
Felix Marty 2023-07-12 18:31:49 +00:00
parent faa5b52fdc
commit 38c2be5926
4 changed files with 923 additions and 636 deletions

View File

@ -1,93 +1,192 @@
{ {
"generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L",
"details": { "details": {
"best_of_sequences": null,
"finish_reason": "length", "finish_reason": "length",
"generated_tokens": 10, "generated_tokens": 20,
"seed": null,
"prefill": [ "prefill": [
{ {
"id": 589, "id": 589,
"logprob": null, "text": "def",
"text": "def" "logprob": null
}, },
{ {
"id": 1459, "id": 3226,
"logprob": -5.6289062, "text": " ge",
"text": " print" "logprob": -9.0234375
},
{
"id": 21017,
"text": "ometric",
"logprob": -9.09375
}, },
{ {
"id": 81, "id": 81,
"logprob": -1.6005859, "text": "_",
"text": "_" "logprob": -0.25610352
}, },
{ {
"id": 7656, "id": 6009,
"logprob": -5.9921875, "text": "mean",
"text": "hello" "logprob": -2.1835938
},
{
"id": 26,
"text": "(",
"logprob": -0.29907227
},
{
"id": 62,
"text": "L",
"logprob": -5.6015625
},
{
"id": 44,
"text": ":",
"logprob": -3.0898438
},
{
"id": 1682,
"text": " List",
"logprob": -0.68359375
},
{
"id": 77,
"text": "[",
"logprob": -0.3869629
},
{
"id": 1808,
"text": "float",
"logprob": -0.95751953
},
{
"id": 10794,
"text": "]):",
"logprob": -2.5507812
} }
], ],
"seed": null,
"tokens": [ "tokens": [
{
"id": 2262,
"logprob": -0.7705078,
"special": false,
"text": "():"
},
{ {
"id": 284, "id": 284,
"logprob": -0.2590332, "text": "\n ",
"special": false, "logprob": -1.171875,
"text": "\n " "special": false
}, },
{ {
"id": 1459, "id": 442,
"logprob": -0.39379883, "text": " return",
"special": false, "logprob": -0.9453125,
"text": " print" "special": false
}, },
{ {
"id": 440, "id": 3632,
"logprob": -0.61376953, "text": " sum",
"special": false, "logprob": -1.4013672,
"text": "(\"" "special": false
}, },
{ {
"id": 8279, "id": 26,
"logprob": -0.47338867, "text": "(",
"special": false, "logprob": -0.083618164,
"text": "Hello" "special": false
}, },
{ {
"id": 10896, "id": 62,
"logprob": -1.5068359, "text": "L",
"special": false, "logprob": -0.098083496,
"text": " World" "special": false
}, },
{ {
"id": 657, "id": 27,
"logprob": -0.80810547, "text": ")",
"special": false, "logprob": -0.30493164,
"text": "\")" "special": false
},
{
"id": 517,
"text": " /",
"logprob": -0.4074707,
"special": false
},
{
"id": 2069,
"text": " len",
"logprob": -0.041015625,
"special": false
},
{
"id": 26,
"text": "(",
"logprob": -0.0011863708,
"special": false
},
{
"id": 62,
"text": "L",
"logprob": -0.0005221367,
"special": false
},
{
"id": 27,
"text": ")",
"logprob": -0.0017499924,
"special": false
},
{
"id": 478,
"text": "\n\n",
"logprob": -0.69873047,
"special": false
}, },
{ {
"id": 203, "id": 203,
"logprob": -0.7397461, "text": "\n",
"special": false, "logprob": -0.041229248,
"text": "\n" "special": false
},
{
"id": 203,
"logprob": -0.35229492,
"special": false,
"text": "\n"
}, },
{ {
"id": 589, "id": 589,
"logprob": -1.0371094, "text": "def",
"special": false, "logprob": -0.27929688,
"text": "def" "special": false
},
{
"id": 3226,
"text": " ge",
"logprob": -1.7089844,
"special": false
},
{
"id": 21017,
"text": "ometric",
"logprob": -0.010757446,
"special": false
},
{
"id": 81,
"text": "_",
"logprob": -0.0090408325,
"special": false
},
{
"id": 6009,
"text": "mean",
"logprob": -0.024932861,
"special": false
},
{
"id": 26,
"text": "(",
"logprob": -0.06451416,
"special": false
},
{
"id": 62,
"text": "L",
"logprob": -0.09832764,
"special": false
} }
] ]
}, }
"generated_text": "():\n print(\"Hello World\")\n\ndef"
} }

View File

@ -1,393 +1,192 @@
{ {
"generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric",
"details": { "details": {
"best_of_sequences": null,
"finish_reason": "length", "finish_reason": "length",
"generated_tokens": 60, "generated_tokens": 20,
"seed": 0,
"prefill": [ "prefill": [
{ {
"id": 589, "id": 589,
"logprob": null, "text": "def",
"text": "def" "logprob": null
}, },
{ {
"id": 1459, "id": 3226,
"logprob": -5.6328125, "text": " ge",
"text": " print" "logprob": -9.0234375
},
{
"id": 21017,
"text": "ometric",
"logprob": -9.0859375
}, },
{ {
"id": 81, "id": 81,
"logprob": -1.6035156, "text": "_",
"text": "_" "logprob": -0.25878906
}, },
{ {
"id": 7656, "id": 6009,
"logprob": -5.9882812, "text": "mean",
"text": "hello" "logprob": -2.2109375
},
{
"id": 26,
"text": "(",
"logprob": -0.30371094
},
{
"id": 62,
"text": "L",
"logprob": -5.6054688
},
{
"id": 44,
"text": ":",
"logprob": -3.0722656
},
{
"id": 1682,
"text": " List",
"logprob": -0.6879883
},
{
"id": 77,
"text": "[",
"logprob": -0.38500977
},
{
"id": 1808,
"text": "float",
"logprob": -0.984375
},
{
"id": 10794,
"text": "]):",
"logprob": -2.5351562
} }
], ],
"seed": 0,
"tokens": [ "tokens": [
{
"id": 2262,
"logprob": -0.042999268,
"special": false,
"text": "():"
},
{ {
"id": 284, "id": 284,
"logprob": 0.0, "text": "\n ",
"special": false, "logprob": -0.05831909,
"text": "\n " "special": false
}, },
{ {
"id": 1459, "id": 442,
"logprob": 0.0, "text": " return",
"special": false, "logprob": 0,
"text": " print" "special": false
}, },
{ {
"id": 440, "id": 11665,
"logprob": 0.0, "text": " reduce",
"special": false, "logprob": -0.9741211,
"text": "(\"" "special": false
},
{
"id": 8279,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 10896,
"logprob": -0.38549805,
"special": false,
"text": " World"
},
{
"id": 657,
"logprob": -0.5229492,
"special": false,
"text": "\")"
},
{
"id": 203,
"logprob": -0.10632324,
"special": false,
"text": "\n"
},
{
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 589,
"logprob": -0.20141602,
"special": false,
"text": "def"
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 7656,
"logprob": 0.0,
"special": false,
"text": "hello"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 426,
"logprob": 0.0,
"special": false,
"text": "name"
}, },
{ {
"id": 26, "id": 26,
"logprob": 0.0, "text": "(",
"special": false, "logprob": 0,
"text": "(" "special": false
}, },
{ {
"id": 426, "id": 5962,
"logprob": 0.0, "text": "lambda",
"special": false, "logprob": 0,
"text": "name" "special": false
}, },
{ {
"id": 711, "id": 816,
"logprob": 0.0, "text": " x",
"special": false, "logprob": 0,
"text": "):" "special": false
},
{
"id": 284,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 440,
"logprob": -0.16027832,
"special": false,
"text": "(\""
},
{
"id": 8279,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 313,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 474,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 636,
"logprob": 0.0,
"special": false,
"text": " name"
},
{
"id": 27,
"logprob": 0.0,
"special": false,
"text": ")"
},
{
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 589,
"logprob": 0.0,
"special": false,
"text": "def"
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 7656,
"logprob": 0.0,
"special": false,
"text": "hello"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 426,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 381,
"logprob": 0.0,
"special": false,
"text": "age"
},
{
"id": 26,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 426,
"logprob": 0.0,
"special": false,
"text": "name"
}, },
{ {
"id": 30, "id": 30,
"logprob": 0.0, "text": ",",
"special": false, "logprob": 0,
"text": "," "special": false
}, },
{ {
"id": 11442, "id": 533,
"logprob": 0.0, "text": " y",
"special": false, "logprob": 0,
"text": " age" "special": false
}, },
{ {
"id": 711, "id": 44,
"logprob": 0.0, "text": ":",
"special": false, "logprob": 0,
"text": "):" "special": false
}, },
{ {
"id": 284, "id": 816,
"logprob": 0.0, "text": " x",
"special": false, "logprob": 0,
"text": "\n " "special": false
}, },
{ {
"id": 1459, "id": 319,
"logprob": 0.0, "text": " *",
"special": false, "logprob": 0,
"text": " print" "special": false
}, },
{ {
"id": 440, "id": 533,
"logprob": 0.0, "text": " y",
"special": false, "logprob": 0,
"text": "(\"" "special": false
}, },
{ {
"id": 8279, "id": 30,
"logprob": 0.0, "text": ",",
"special": false, "logprob": 0,
"text": "Hello" "special": false
}, },
{ {
"id": 313, "id": 498,
"logprob": 0.0, "text": " L",
"special": false, "logprob": 0,
"text": " \"" "special": false
}, },
{ {
"id": 474, "id": 27,
"logprob": 0.0, "text": ")",
"special": false, "logprob": 0,
"text": " +" "special": false
},
{
"id": 636,
"logprob": 0.0,
"special": false,
"text": " name"
},
{
"id": 474,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 313,
"logprob": -0.6328125,
"special": false,
"text": " \""
},
{
"id": 313,
"logprob": -1.7011719,
"special": false,
"text": " \""
},
{
"id": 474,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 596,
"logprob": 0.0,
"special": false,
"text": " str"
},
{
"id": 26,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 381,
"logprob": 0.0,
"special": false,
"text": "age"
},
{
"id": 490,
"logprob": 0.0,
"special": false,
"text": "))"
}, },
{ {
"id": 203, "id": 203,
"logprob": 0.0, "text": "\n",
"special": false, "logprob": -0.11279297,
"text": "\n" "special": false
}, },
{ {
"id": 203, "id": 203,
"logprob": 0.0, "text": "\n",
"special": false, "logprob": 0,
"text": "\n" "special": false
}, },
{ {
"id": 589, "id": 589,
"logprob": 0.0, "text": "def",
"special": false, "logprob": 0,
"text": "def" "special": false
}, },
{ {
"id": 1459, "id": 3226,
"logprob": 0.0, "text": " ge",
"special": false, "logprob": 0,
"text": " print" "special": false
},
{
"id": 21017,
"text": "ometric",
"logprob": 0,
"special": false
} }
] ]
}, }
"generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print"
} }

View File

@ -2,55 +2,48 @@ import pytest
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_santacoder_gptq_handle(launcher): def flash_starcoder_gptq_handle(launcher):
with launcher("Narsil/starcoder-gptq", num_shard=2, quantize="gptq") as handle: with launcher("Narsil/starcoder-gptq", num_shard=2, quantize="gptq") as handle:
yield handle yield handle
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
async def flash_santacoder_gptq(flash_santacoder_gptq_handle): async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
await flash_santacoder_gptq_handle.health(300) await flash_starcoder_gptq_handle.health(300)
return flash_santacoder_gptq_handle.client return flash_starcoder_gptq_handle.client
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_santacoder_gptq(flash_santacoder_gptq, response_snapshot): async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
response = await flash_santacoder_gptq.generate( response = await flash_starcoder_gptq.generate(
'def sum(L: List[int]):\n"""Sums all elements from the list L."""', max_new_tokens=40, decoder_input_details=True "def geometric_mean(L: List[float]):", max_new_tokens=20, decoder_input_details=True
) )
# assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 20
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_santacoder_gptq_all_params(flash_santacoder_gptq, response_snapshot): async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, response_snapshot):
response = await flash_santacoder_gptq.generate( response = await flash_starcoder_gptq.generate(
'def sum(L: List[int]):\n"""Sums all elements from the list L."""', "def geometric_mean(L: List[float]):",
max_new_tokens=10, max_new_tokens=20,
repetition_penalty=1.2, temperature=0.2,
return_full_text=True, top_p=0.95,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True, decoder_input_details=True,
seed=0, seed=0,
) )
#assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 20
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_santacoder_gptq_load(flash_santacoder_gptq, generate_load, response_snapshot): async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, response_snapshot):
responses = await generate_load(flash_santacoder_gptq, 'def sum(L: List[int]):\n"""Sums all elements from the list L."""', max_new_tokens=10, n=4) responses = await generate_load(flash_starcoder_gptq, "def geometric_mean(L: List[float]):", max_new_tokens=20, n=4)
assert len(responses) == 4 assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]) assert all([r.generated_text == responses[0].generated_text for r in responses])