From 70713fc2929e1d6f92051d60202fb2991eca6089 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 14 May 2024 23:09:28 +0000 Subject: [PATCH] fix: improve pali test and add snapshot --- .../test_flash_pali_gemma.json | 25 +++++++++++++++++++ .../models/test_flash_pali_gemma.py | 7 +++--- 2 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma.json diff --git a/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma.json b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma.json new file mode 100644 index 00000000..f78428fb --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma.json @@ -0,0 +1,25 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 2, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 54901, + "logprob": -0.61621094, + "special": false, + "text": "beach" + }, + { + "id": 1, + "logprob": -0.11273193, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": "beach" +} diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py index e1fad0d3..61d816e1 100644 --- a/integration-tests/models/test_flash_pali_gemma.py +++ b/integration-tests/models/test_flash_pali_gemma.py @@ -7,7 +7,7 @@ import base64 @pytest.fixture(scope="module") def flash_pali_gemma_handle(launcher): with launcher( - "Tinkering/test-bvhf", + "gv-hf/paligemma-3b-mix-224", num_shard=1, max_input_length=4000, max_total_tokens=4096, @@ -31,7 +31,8 @@ def get_cow_beach(): @pytest.mark.private async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): cow = get_cow_beach() - inputs = f"Where is the cow standing?\n![]({cow})" + inputs = f"![]({cow})Where is the cow standing?\n" response = await flash_pali_gemma.generate(inputs, max_new_tokens=20) - assert response.generated_text == "\nbeach" + assert response.generated_text == "beach" + assert response == response_snapshot