fix tests

Felix Marty 2023-07-13 10:38:08 +00:00
parent 38c2be5926
commit 2ae65b45a8
5 changed files with 471 additions and 707 deletions


@@ -1,6 +1,7 @@
{
"generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L",
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"seed": null,
@@ -18,173 +19,173 @@
{
"id": 21017,
"text": "ometric",
"logprob": -9.09375
"logprob": -9.0859375
},
{
"id": 81,
"text": "_",
"logprob": -0.25610352
"logprob": -0.25878906
},
{
"id": 6009,
"text": "mean",
"logprob": -2.1835938
"logprob": -2.2109375
},
{
"id": 26,
"text": "(",
"logprob": -0.29907227
"logprob": -0.30371094
},
{
"id": 62,
"text": "L",
"logprob": -5.6015625
"logprob": -5.6054688
},
{
"id": 44,
"text": ":",
"logprob": -3.0898438
"logprob": -3.0722656
},
{
"id": 1682,
"text": " List",
"logprob": -0.68359375
"logprob": -0.6879883
},
{
"id": 77,
"text": "[",
"logprob": -0.3869629
"logprob": -0.38500977
},
{
"id": 1808,
"text": "float",
"logprob": -0.95751953
"logprob": -0.984375
},
{
"id": 10794,
"text": "]):",
"logprob": -2.5507812
"logprob": -2.5351562
}
],
"tokens": [
{
"id": 284,
"text": "\n ",
"logprob": -1.171875,
"logprob": -1.1738281,
"special": false
},
{
"id": 442,
"text": " return",
"logprob": -0.9453125,
"logprob": -0.95947266,
"special": false
},
{
"id": 3632,
"text": " sum",
"logprob": -1.4013672,
"logprob": -1.4199219,
"special": false
},
{
"id": 26,
"text": "(",
"logprob": -0.083618164,
"logprob": -0.085876465,
"special": false
},
{
"id": 62,
"text": "L",
"logprob": -0.098083496,
"logprob": -0.09875488,
"special": false
},
{
"id": 27,
"text": ")",
"logprob": -0.30493164,
"logprob": -0.30517578,
"special": false
},
{
"id": 517,
"text": " /",
"logprob": -0.4074707,
"logprob": -0.42089844,
"special": false
},
{
"id": 2069,
"text": " len",
"logprob": -0.041015625,
"logprob": -0.042053223,
"special": false
},
{
"id": 26,
"text": "(",
"logprob": -0.0011863708,
"logprob": -0.0011806488,
"special": false
},
{
"id": 62,
"text": "L",
"logprob": -0.0005221367,
"logprob": -0.0005259514,
"special": false
},
{
"id": 27,
"text": ")",
"logprob": -0.0017499924,
"logprob": -0.0017633438,
"special": false
},
{
"id": 478,
"text": "\n\n",
"logprob": -0.69873047,
"logprob": -0.69189453,
"special": false
},
{
"id": 203,
"text": "\n",
"logprob": -0.041229248,
"logprob": -0.041870117,
"special": false
},
{
"id": 589,
"text": "def",
"logprob": -0.27929688,
"logprob": -0.27856445,
"special": false
},
{
"id": 3226,
"text": " ge",
"logprob": -1.7089844,
"logprob": -1.7255859,
"special": false
},
{
"id": 21017,
"text": "ometric",
"logprob": -0.010757446,
"logprob": -0.011291504,
"special": false
},
{
"id": 81,
"text": "_",
"logprob": -0.0090408325,
"logprob": -0.008430481,
"special": false
},
{
"id": 6009,
"text": "mean",
"logprob": -0.024932861,
"logprob": -0.025787354,
"special": false
},
{
"id": 26,
"text": "(",
"logprob": -0.06451416,
"logprob": -0.073913574,
"special": false
},
{
"id": 62,
"text": "L",
"logprob": -0.09832764,
"logprob": -0.09967041,
"special": false
}
]
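
In each paired logprob line above, the old value precedes the regenerated one, and they differ only in the low-order bits (for example -9.09375 versus -9.0859375 for "ometric"), the kind of drift that forces a snapshot refresh after a kernel or quantization change. Below is a minimal sketch of comparing per-token logprobs with an explicit tolerance instead of exact equality; it is a hypothetical helper, not the project's response_snapshot fixture.

import math

def logprobs_close(expected, actual, rel_tol=5e-2):
    """Compare two per-token logprob sequences with a relative tolerance.

    Hypothetical helper: exact float comparison breaks whenever a kernel or
    quantization change shifts the low-order bits, as in the snapshot above.
    """
    if len(expected) != len(actual):
        return False
    return all(math.isclose(e, a, rel_tol=rel_tol) for e, a in zip(expected, actual))

# Values taken from the diff above: old vs. new logprob for the "ometric" prefill token.
assert logprobs_close([-9.09375], [-9.0859375])

Whether such a tolerance is acceptable is a per-project choice; this commit instead regenerates the expected values wholesale.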


@@ -1,6 +1,7 @@
{
"generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric",
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"seed": 0,


@@ -17,9 +17,8 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
@pytest.mark.private
async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
response = await flash_starcoder_gptq.generate(
"def geometric_mean(L: List[float]):", max_new_tokens=20, decoder_input_details=True
"def geometric_mean(L: List[float]):", max_new_tokens=20, decoder_input_details=True,
)
assert response.details.generated_tokens == 20
assert response == response_snapshot
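
The decoder_input_details=True flag is what produces the "prefill" block in the snapshot above (one id/text/logprob entry per prompt token) alongside the generated "tokens" list. Below is a hedged sketch of additional checks a test could run over both lists, assuming the response object mirrors the JSON layout shown above; the attribute names are an assumption, not the project's documented API.

def check_details(response, max_new_tokens=20):
    # Prompt tokens only come back because decoder_input_details=True was set;
    # attribute names assume the response mirrors the JSON snapshot above.
    assert len(response.details.prefill) > 0
    assert all(t.logprob is None or t.logprob <= 0 for t in response.details.prefill)

    # One entry per generated token, and none of them should be special tokens here.
    assert response.details.generated_tokens == max_new_tokens
    assert len(response.details.tokens) == max_new_tokens
    assert not any(t.special for t in response.details.tokens)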
@@ -35,7 +34,6 @@ async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, response_snapshot):
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 20
assert response == response_snapshot
@@ -43,7 +41,7 @@ async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, response_snapshot):
responses = await generate_load(flash_starcoder_gptq, "def geometric_mean(L: List[float]):", max_new_tokens=20, n=4)
responses = await generate_load(flash_starcoder_gptq, "def geometric_mean(L: List[float]):", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
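
The load test above fans the same prompt out over four concurrent requests and checks that greedy decoding returns identical text for all of them. Below is a minimal sketch of what a generate_load-style helper could look like, assuming an async client with a generate coroutine like the one used above; this is hypothetical, the project's real fixture lives in its test conftest.

import asyncio

async def generate_load(client, prompt, max_new_tokens, n):
    """Fire n identical generation requests concurrently and return every response.

    Hypothetical sketch of a load-test helper; assumes `client.generate` is the
    async method used in the tests above.
    """
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens, decoder_input_details=True)
        for _ in range(n)
    ]
    return await asyncio.gather(*futures)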


@@ -378,7 +378,7 @@ class Block(nn.Module):
max_s,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states,
cu_seqlen_prefill,