From 0d1bf9e983100206ef8d9e332532ee54564f8483 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 19 Dec 2024 01:54:10 +0000
Subject: [PATCH] feat: consolidate changes with existing VLMs and add support
 and test for SmolVLM

Fixes Idefics2ForConditionalGeneration to instantiate Idefics2Connector
instead of Idefics3Connector, passes return_row_col_info only to image
processors that accept it (Idefics3ImageProcessor), and adds an
integration test plus snapshot for HuggingFaceTB/SmolVLM-Instruct.
---
 .../test_flash_smolvlm_next_simple_url.json   | 61 +++++++++++++++++++
 integration-tests/models/test_smolvlm.py      | 31 ++++++++++
 .../models/custom_modeling/idefics2.py        |  2 +-
 .../models/vlm_causal_lm.py                   |  7 ++-
 4 files changed, 99 insertions(+), 2 deletions(-)
 create mode 100644 integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json
 create mode 100644 integration-tests/models/test_smolvlm.py

diff --git a/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json b/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json
new file mode 100644
index 00000000..17a69d0d
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json
@@ -0,0 +1,61 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "eos_token",
+    "generated_tokens": 8,
+    "prefill": [],
+    "seed": null,
+    "tokens": [
+      {
+        "id": 330,
+        "logprob": -0.118652344,
+        "special": false,
+        "text": " A"
+      },
+      {
+        "id": 11426,
+        "logprob": -0.28320312,
+        "special": false,
+        "text": " bee"
+      },
+      {
+        "id": 335,
+        "logprob": -0.95703125,
+        "special": false,
+        "text": " on"
+      },
+      {
+        "id": 253,
+        "logprob": -0.06982422,
+        "special": false,
+        "text": " a"
+      },
+      {
+        "id": 11986,
+        "logprob": -0.49414062,
+        "special": false,
+        "text": " pink"
+      },
+      {
+        "id": 8525,
+        "logprob": -0.07763672,
+        "special": false,
+        "text": " flower"
+      },
+      {
+        "id": 30,
+        "logprob": -1.0703125,
+        "special": false,
+        "text": "."
+      },
+      {
+        "id": 49154,
+        "logprob": -0.092285156,
+        "special": true,
+        "text": "<end_of_utterance>"
+      }
+    ],
+    "top_tokens": null
+  },
+  "generated_text": " A bee on a pink flower."
+}
diff --git a/integration-tests/models/test_smolvlm.py b/integration-tests/models/test_smolvlm.py
new file mode 100644
index 00000000..cd105d84
--- /dev/null
+++ b/integration-tests/models/test_smolvlm.py
@@ -0,0 +1,31 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_smolvlm_next_handle(launcher):
+    with launcher("HuggingFaceTB/SmolVLM-Instruct") as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_smolvlm_next(flash_smolvlm_next_handle):
+    await flash_smolvlm_next_handle.health(300)
+    return flash_smolvlm_next_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_smolvlm_next_simple_url(flash_smolvlm_next, response_snapshot):
+    bee_image_url = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg"
+    query = "What is in this image?"
+    response = await flash_smolvlm_next.generate(
+        f"<|begin_of_text|><|begin_of_text|>User:![]({bee_image_url}){query}<end_of_utterance>\nAssistant:",
+        max_new_tokens=10,
+        seed=1337,
+    )
+    print(response)
+    assert (
+        response.generated_text == " A bee on a pink flower."
+ ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 8 + assert response == response_snapshot diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 6c1d5823..2e499001 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -916,7 +916,7 @@ class Idefics2ForConditionalGeneration(nn.Module): ) config.quantize = None - self.connector = Idefics3Connector( + self.connector = Idefics2Connector( prefix=f"{prefix}.model.connector" if prefix else "model.connector", config=config, weights=weights, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 306da497..c1908d8e 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -280,8 +280,13 @@ class VlmCausalLMBatch(FlashCausalLMBatch): raise RuntimeError(f"Invalid chunk type {chunk_type}") if images: + kwargs = {} + match processor.image_processor_class: + case "Idefics3ImageProcessor": + kwargs["return_row_col_info"] = True + image_inputs = processor.image_processor( - images, return_tensors="pt", return_row_col_info=True + images, return_tensors="pt", **kwargs ) else: image_inputs = None