mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
Enable paligemma2 (#2807)
* feat: support loading gemma2 as vlm text model * feat: add test for paligemma2
This commit is contained in:
parent
08f6fa0b59
commit
9f5c9a5e22
@ -0,0 +1,133 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 20,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.73046875,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30234,
|
||||||
|
"logprob": -2.328125,
|
||||||
|
"special": false,
|
||||||
|
"text": "Brown"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.12060547,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3726,
|
||||||
|
"logprob": -1.7734375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Car"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.041503906,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2915,
|
||||||
|
"logprob": -1.796875,
|
||||||
|
"special": false,
|
||||||
|
"text": "Color"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.039794922,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 19178,
|
||||||
|
"logprob": -1.96875,
|
||||||
|
"special": false,
|
||||||
|
"text": "Cool"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.080566406,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 40544,
|
||||||
|
"logprob": -2.1875,
|
||||||
|
"special": false,
|
||||||
|
"text": "Decor"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.033935547,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13936,
|
||||||
|
"logprob": -1.6328125,
|
||||||
|
"special": false,
|
||||||
|
"text": "Green"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.16210938,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 955,
|
||||||
|
"logprob": -2.015625,
|
||||||
|
"special": false,
|
||||||
|
"text": "..."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.14746094,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 955,
|
||||||
|
"logprob": -0.73828125,
|
||||||
|
"special": false,
|
||||||
|
"text": "..."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.051513672,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 955,
|
||||||
|
"logprob": -0.34765625,
|
||||||
|
"special": false,
|
||||||
|
"text": "..."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.020141602,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 955,
|
||||||
|
"logprob": -0.11767578,
|
||||||
|
"special": false,
|
||||||
|
"text": "..."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\nBrown\nCar\nColor\nCool\nDecor\nGreen\n...\n...\n...\n..."
|
||||||
|
}
|
29
integration-tests/models/test_flash_pali_gemma2.py
Normal file
29
integration-tests/models/test_flash_pali_gemma2.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_pali_gemma_handle(launcher):
|
||||||
|
with launcher(
|
||||||
|
"google/paligemma2-3b-pt-224",
|
||||||
|
) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_pali_gemma(flash_pali_gemma_handle):
|
||||||
|
await flash_pali_gemma_handle.health(300)
|
||||||
|
return flash_pali_gemma_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
async def test_flash_pali_gemma_image(flash_pali_gemma, response_snapshot):
|
||||||
|
car_image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
|
||||||
|
response = await flash_pali_gemma.generate(
|
||||||
|
f"",
|
||||||
|
max_new_tokens=20,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
response.generated_text
|
||||||
|
== "\nBrown\nCar\nColor\nCool\nDecor\nGreen\n...\n...\n...\n..."
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response == response_snapshot
|
@ -17,6 +17,12 @@ def load_text_model(prefix, config, weights, name=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
|
return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
|
||||||
|
elif config.model_type == "gemma2":
|
||||||
|
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||||
|
FlashGemma2ForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
return FlashGemma2ForCausalLM(prefix, config, weights)
|
||||||
elif config.model_type == "paligemma":
|
elif config.model_type == "paligemma":
|
||||||
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||||
FlashGemmaForCausalLM,
|
FlashGemmaForCausalLM,
|
||||||
|
Loading…
Reference in New Issue
Block a user