diff --git a/integration-tests/models/__snapshots__/test_flash_pali_gemma2/test_flash_pali_gemma_image.json b/integration-tests/models/__snapshots__/test_flash_pali_gemma2/test_flash_pali_gemma_image.json new file mode 100644 index 00000000..bc75bce4 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_pali_gemma2/test_flash_pali_gemma_image.json @@ -0,0 +1,133 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 20, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 108, + "logprob": -0.73046875, + "special": false, + "text": "\n" + }, + { + "id": 30234, + "logprob": -2.328125, + "special": false, + "text": "Brown" + }, + { + "id": 108, + "logprob": -0.12060547, + "special": false, + "text": "\n" + }, + { + "id": 3726, + "logprob": -1.7734375, + "special": false, + "text": "Car" + }, + { + "id": 108, + "logprob": -0.041503906, + "special": false, + "text": "\n" + }, + { + "id": 2915, + "logprob": -1.796875, + "special": false, + "text": "Color" + }, + { + "id": 108, + "logprob": -0.039794922, + "special": false, + "text": "\n" + }, + { + "id": 19178, + "logprob": -1.96875, + "special": false, + "text": "Cool" + }, + { + "id": 108, + "logprob": -0.080566406, + "special": false, + "text": "\n" + }, + { + "id": 40544, + "logprob": -2.1875, + "special": false, + "text": "Decor" + }, + { + "id": 108, + "logprob": -0.033935547, + "special": false, + "text": "\n" + }, + { + "id": 13936, + "logprob": -1.6328125, + "special": false, + "text": "Green" + }, + { + "id": 108, + "logprob": -0.16210938, + "special": false, + "text": "\n" + }, + { + "id": 955, + "logprob": -2.015625, + "special": false, + "text": "..." + }, + { + "id": 108, + "logprob": -0.14746094, + "special": false, + "text": "\n" + }, + { + "id": 955, + "logprob": -0.73828125, + "special": false, + "text": "..." + }, + { + "id": 108, + "logprob": -0.051513672, + "special": false, + "text": "\n" + }, + { + "id": 955, + "logprob": -0.34765625, + "special": false, + "text": "..." + }, + { + "id": 108, + "logprob": -0.020141602, + "special": false, + "text": "\n" + }, + { + "id": 955, + "logprob": -0.11767578, + "special": false, + "text": "..." + } + ], + "top_tokens": null + }, + "generated_text": "\nBrown\nCar\nColor\nCool\nDecor\nGreen\n...\n...\n...\n..." +} diff --git a/integration-tests/models/test_flash_pali_gemma2.py b/integration-tests/models/test_flash_pali_gemma2.py new file mode 100644 index 00000000..23705385 --- /dev/null +++ b/integration-tests/models/test_flash_pali_gemma2.py @@ -0,0 +1,29 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_pali_gemma_handle(launcher): + with launcher( + "google/paligemma2-3b-pt-224", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_pali_gemma(flash_pali_gemma_handle): + await flash_pali_gemma_handle.health(300) + return flash_pali_gemma_handle.client + + +async def test_flash_pali_gemma_image(flash_pali_gemma, response_snapshot): + car_image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" + response = await flash_pali_gemma.generate( + f"![]({car_image})", + max_new_tokens=20, + ) + assert ( + response.generated_text + == "\nBrown\nCar\nColor\nCool\nDecor\nGreen\n...\n...\n...\n..." + ) + + assert response == response_snapshot diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index e5c44045..82e409a6 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -17,6 +17,12 @@ def load_text_model(prefix, config, weights, name=None): ) return FlashGemmaForCausalLM(prefix, config, weights, causal=False) + elif config.model_type == "gemma2": + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, + ) + + return FlashGemma2ForCausalLM(prefix, config, weights) elif config.model_type == "paligemma": from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( FlashGemmaForCausalLM,