feat: consolidate changes with existing vlms and add support and test for smolvlm

drbh 2024-12-19 01:54:10 +00:00
parent 064e040ee3
commit 0d1bf9e983
4 changed files with 99 additions and 2 deletions

@@ -0,0 +1,61 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 8,
    "prefill": [],
    "seed": null,
    "tokens": [
      {
        "id": 330,
        "logprob": -0.118652344,
        "special": false,
        "text": " A"
      },
      {
        "id": 11426,
        "logprob": -0.28320312,
        "special": false,
        "text": " bee"
      },
      {
        "id": 335,
        "logprob": -0.95703125,
        "special": false,
        "text": " on"
      },
      {
        "id": 253,
        "logprob": -0.06982422,
        "special": false,
        "text": " a"
      },
      {
        "id": 11986,
        "logprob": -0.49414062,
        "special": false,
        "text": " pink"
      },
      {
        "id": 8525,
        "logprob": -0.07763672,
        "special": false,
        "text": " flower"
      },
      {
        "id": 30,
        "logprob": -1.0703125,
        "special": false,
        "text": "."
      },
      {
        "id": 49154,
        "logprob": -0.092285156,
        "special": true,
        "text": "<end_of_utterance>"
      }
    ],
    "top_tokens": null
  },
  "generated_text": " A bee on a pink flower."
}
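
As a sanity check, the snapshot above is internally consistent: concatenating the non-special token texts reproduces "generated_text", and "generated_tokens" counts every emitted token including the final <end_of_utterance>. A minimal sketch of that check in Python (the snapshot file name is an assumption):

import json

# Load the snapshot shown above (hypothetical file name).
with open("test_flash_smolvlm_next_simple_url.json") as f:
    snapshot = json.load(f)

tokens = snapshot["details"]["tokens"]

# The final text is the concatenation of the non-special token texts.
assert snapshot["generated_text"] == "".join(
    t["text"] for t in tokens if not t["special"]
)

# The token count includes the special <end_of_utterance> stop token.
assert snapshot["details"]["generated_tokens"] == len(tokens)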

@@ -0,0 +1,31 @@
import pytest


@pytest.fixture(scope="module")
def flash_smolvlm_next_handle(launcher):
    with launcher("HuggingFaceTB/SmolVLM-Instruct") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_smolvlm_next(flash_smolvlm_next_handle):
    await flash_smolvlm_next_handle.health(300)
    return flash_smolvlm_next_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_smolvlm_next_simple_url(flash_smolvlm_next, response_snapshot):
    ny_skyline = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg"
    query = "What is in this image?"

    response = await flash_smolvlm_next.generate(
        f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\nAssistant:",
        max_new_tokens=10,
        seed=1337,
    )

    print(response)

    assert (
        response.generated_text == " A bee on a pink flower."
    ), f"{repr(response.generated_text)}"
    assert response.details.generated_tokens == 8
    assert response == response_snapshot
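
Outside the pytest harness, the same request can be reproduced against a running TGI server through its /generate HTTP endpoint; the markdown-style ![](url) reference in the prompt is what TGI's VLM preprocessing resolves into image tokens. A hedged sketch (the base URL is an assumption; prompt and parameters are copied from the test):

import requests

base_url = "http://127.0.0.1:8080"  # assumed local TGI endpoint
image_url = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg"
query = "What is in this image?"

# Same prompt template the test builds: the image is referenced inline and
# fetched/tokenized server-side.
prompt = (
    "<|begin_of_text|><|begin_of_text|>"
    f"User:![]({image_url}){query}<end_of_utterance>\nAssistant:"
)

resp = requests.post(
    f"{base_url}/generate",
    json={"inputs": prompt, "parameters": {"max_new_tokens": 10, "seed": 1337}},
    timeout=60,
)
print(resp.json()["generated_text"])  # expected: " A bee on a pink flower."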

@@ -916,7 +916,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
         )
         config.quantize = None
-        self.connector = Idefics3Connector(
+        self.connector = Idefics2Connector(
             prefix=f"{prefix}.model.connector" if prefix else "model.connector",
             config=config,
             weights=weights,

@@ -280,8 +280,13 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 raise RuntimeError(f"Invalid chunk type {chunk_type}")

         if images:
+            kwargs = {}
+            match processor.image_processor_class:
+                case "Idefics3ImageProcessor":
+                    kwargs["return_row_col_info"] = True
+
             image_inputs = processor.image_processor(
-                images, return_tensors="pt", return_row_col_info=True
+                images, return_tensors="pt", **kwargs
             )
         else:
             image_inputs = None
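
The match on processor.image_processor_class makes the row/column kwarg opt-in: only the Idefics3-family image processor (which SmolVLM builds on) accepts return_row_col_info, while other VLM image processors keep receiving no extra kwargs. A standalone sketch of the same dispatch outside the server (model id and image URL reused from the test; the rest is illustrative):

from io import BytesIO

import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")

url = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg"
images = [Image.open(BytesIO(requests.get(url, timeout=60).content))]

# Only the Idefics3-family processor understands return_row_col_info, which
# asks it to report how each image was split into rows/columns of sub-patches.
kwargs = {}
match processor.image_processor_class:
    case "Idefics3ImageProcessor":
        kwargs["return_row_col_info"] = True

image_inputs = processor.image_processor(images, return_tensors="pt", **kwargs)
print(image_inputs.keys())  # pixel values, plus rows/cols info for Idefics3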