text-generation-inference/integration-tests/models/test_flash_pali_gemma.py
OlivierDehaene a6a0c97ed9
feat: prefill chunking (#2600)
* wip

* rollback

* refactor to use prefix/postfix namming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot
of the times).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-10-16 12:49:33 +02:00

49 lines
1.4 KiB
Python

import pytest
@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
with launcher(
"google/paligemma-3b-pt-224",
num_shard=1,
revision="float16",
max_input_length=4000,
max_total_tokens=4096,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
await flash_pali_gemma_handle.health(300)
return flash_pali_gemma_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):
inputs = f"![]({cow_beach})Where is the cow standing?\n"
response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
assert response.generated_text == "beach"
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(
flash_pali_gemma, response_snapshot, chicken, cow_beach
):
response = await flash_pali_gemma.generate(
f"caption![]({chicken})![]({cow_beach})\n",
max_new_tokens=20,
)
# Is PaliGemma not able to handle two separate images? At least we
# get output showing that both images are used.
assert (
response.generated_text == "image result for chicken on the beach"
), f"{repr(response.generated_text)}"
assert response == response_snapshot