fix tests

Felix Marty 2023-07-13 10:38:08 +00:00
parent 38c2be5926
commit 2ae65b45a8
5 changed files with 471 additions and 707 deletions

View File

@@ -1,6 +1,7 @@
 {
   "generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L",
   "details": {
+    "best_of_sequences": null,
     "finish_reason": "length",
     "generated_tokens": 20,
     "seed": null,
@@ -18,173 +19,173 @@
       {
         "id": 21017,
         "text": "ometric",
-        "logprob": -9.09375
+        "logprob": -9.0859375
       },
       {
         "id": 81,
         "text": "_",
-        "logprob": -0.25610352
+        "logprob": -0.25878906
       },
       {
         "id": 6009,
         "text": "mean",
-        "logprob": -2.1835938
+        "logprob": -2.2109375
       },
       {
         "id": 26,
         "text": "(",
-        "logprob": -0.29907227
+        "logprob": -0.30371094
       },
       {
         "id": 62,
         "text": "L",
-        "logprob": -5.6015625
+        "logprob": -5.6054688
       },
       {
         "id": 44,
         "text": ":",
-        "logprob": -3.0898438
+        "logprob": -3.0722656
       },
       {
         "id": 1682,
         "text": " List",
-        "logprob": -0.68359375
+        "logprob": -0.6879883
       },
       {
         "id": 77,
         "text": "[",
-        "logprob": -0.3869629
+        "logprob": -0.38500977
       },
       {
         "id": 1808,
         "text": "float",
-        "logprob": -0.95751953
+        "logprob": -0.984375
       },
       {
         "id": 10794,
         "text": "]):",
-        "logprob": -2.5507812
+        "logprob": -2.5351562
       }
     ],
     "tokens": [
       {
         "id": 284,
         "text": "\n ",
-        "logprob": -1.171875,
+        "logprob": -1.1738281,
         "special": false
       },
       {
         "id": 442,
         "text": " return",
-        "logprob": -0.9453125,
+        "logprob": -0.95947266,
         "special": false
       },
       {
         "id": 3632,
         "text": " sum",
-        "logprob": -1.4013672,
+        "logprob": -1.4199219,
         "special": false
       },
       {
         "id": 26,
         "text": "(",
-        "logprob": -0.083618164,
+        "logprob": -0.085876465,
         "special": false
       },
       {
         "id": 62,
         "text": "L",
-        "logprob": -0.098083496,
+        "logprob": -0.09875488,
         "special": false
       },
       {
         "id": 27,
         "text": ")",
-        "logprob": -0.30493164,
+        "logprob": -0.30517578,
         "special": false
       },
       {
         "id": 517,
         "text": " /",
-        "logprob": -0.4074707,
+        "logprob": -0.42089844,
         "special": false
       },
       {
         "id": 2069,
         "text": " len",
-        "logprob": -0.041015625,
+        "logprob": -0.042053223,
         "special": false
       },
       {
         "id": 26,
         "text": "(",
-        "logprob": -0.0011863708,
+        "logprob": -0.0011806488,
         "special": false
       },
       {
         "id": 62,
         "text": "L",
-        "logprob": -0.0005221367,
+        "logprob": -0.0005259514,
         "special": false
       },
       {
         "id": 27,
         "text": ")",
-        "logprob": -0.0017499924,
+        "logprob": -0.0017633438,
         "special": false
       },
       {
         "id": 478,
         "text": "\n\n",
-        "logprob": -0.69873047,
+        "logprob": -0.69189453,
         "special": false
       },
       {
         "id": 203,
         "text": "\n",
-        "logprob": -0.041229248,
+        "logprob": -0.041870117,
         "special": false
       },
       {
         "id": 589,
         "text": "def",
-        "logprob": -0.27929688,
+        "logprob": -0.27856445,
         "special": false
       },
       {
         "id": 3226,
         "text": " ge",
-        "logprob": -1.7089844,
+        "logprob": -1.7255859,
         "special": false
       },
       {
         "id": 21017,
         "text": "ometric",
-        "logprob": -0.010757446,
+        "logprob": -0.011291504,
         "special": false
       },
       {
         "id": 81,
         "text": "_",
-        "logprob": -0.0090408325,
+        "logprob": -0.008430481,
         "special": false
       },
       {
         "id": 6009,
         "text": "mean",
-        "logprob": -0.024932861,
+        "logprob": -0.025787354,
         "special": false
       },
       {
         "id": 26,
         "text": "(",
-        "logprob": -0.06451416,
+        "logprob": -0.073913574,
         "special": false
       },
       {
         "id": 62,
         "text": "L",
-        "logprob": -0.09832764,
+        "logprob": -0.09967041,
         "special": false
       }
     ]
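Note: the changes in this snapshot only nudge logprob values; the record layout is untouched. As a reading aid, a minimal sketch of that layout as Python dataclasses follows. Field names mirror the JSON keys shown above; the name of the array that precedes "tokens" is not visible in this hunk and is assumed here to be the prompt's prefill tokens, and these classes are illustrative rather than the exact client types used by the integration tests.

# Sketch only: field names mirror the snapshot keys above, not necessarily
# the exact classes used by the text_generation client or the test harness.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class InputToken:
    # Entries of the array shown before "tokens" (assumed: prompt prefill tokens).
    id: int
    text: str
    logprob: Optional[float]


@dataclass
class GeneratedToken:
    # Entries of the "tokens" array: one record per generated token.
    id: int
    text: str
    logprob: float
    special: bool


@dataclass
class Details:
    best_of_sequences: Optional[list]
    finish_reason: str
    generated_tokens: int
    seed: Optional[int]
    prefill: List[InputToken]      # assumed key name; not shown in this hunk
    tokens: List[GeneratedToken]


@dataclass
class Response:
    generated_text: str
    details: Details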

View File

@@ -1,6 +1,7 @@
 {
   "generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric",
   "details": {
+    "best_of_sequences": null,
     "finish_reason": "length",
     "generated_tokens": 20,
     "seed": 0,

View File

@@ -17,9 +17,8 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
 @pytest.mark.private
 async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
     response = await flash_starcoder_gptq.generate(
-        "def geometric_mean(L: List[float]):", max_new_tokens=20, decoder_input_details=True
+        "def geometric_mean(L: List[float]):", max_new_tokens=20, decoder_input_details=True,
     )
-
     assert response.details.generated_tokens == 20
     assert response == response_snapshot
@@ -35,7 +34,6 @@ async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, respons
         decoder_input_details=True,
         seed=0,
     )
-
     assert response.details.generated_tokens == 20
     assert response == response_snapshot
@@ -43,7 +41,7 @@ async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, respons
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, response_snapshot):
-    responses = await generate_load(flash_starcoder_gptq, "def geometric_mean(L: List[float]):", max_new_tokens=20, n=4)
+    responses = await generate_load(flash_starcoder_gptq, "def geometric_mean(L: List[float]):", max_new_tokens=10, n=4)
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
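The load test now requests max_new_tokens=10 across n=4 concurrent requests, which presumably accounts for most of the deleted lines in the regenerated snapshots. The generate_load fixture itself is not part of this diff; the following is a minimal sketch of such a helper, assuming an AsyncClient-style generate coroutine. The name generate_load_sketch and its signature are illustrative, not the repository's actual conftest.

import asyncio
from typing import List


async def generate_load_sketch(client, prompt: str, max_new_tokens: int, n: int) -> List:
    # Fire n identical generation requests concurrently and collect the
    # responses, mirroring how the load test above exercises the server.
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens, decoder_input_details=True)
        for _ in range(n)
    ]
    return await asyncio.gather(*futures)

The assertions shown in the hunk then only require four responses whose generated_text values agree with one another.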