This PR adds paligemma modeling code.

Blog post: https://huggingface.co/blog/paligemma
Transformers PR: https://github.com/huggingface/transformers/pull/30814

Install the latest changes and run with:

```bash
# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf

# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```

A basic example sending various requests:

```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")

images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]

prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is there a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]

for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        # Reference the image with markdown image syntax so the server
        # fetches and preprocesses it alongside the text prompt.
        inputs = f"![]({img}){prompt}\n"
        generated_output = client.text_generation(
            inputs, max_new_tokens=30, do_sample=False, stream=False
        )
        print([f"{prompt}\n{generated_output}"])
```

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
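The `InferenceClient` call above wraps TGI's REST API. As a minimal sketch (assuming the launcher command above is serving on port 3000), the same request can be posted directly to the `/generate` endpoint with a raw JSON payload:

```python
import requests

# Minimal sketch: POST the raw payload to TGI's /generate endpoint.
image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
payload = {
    # The image is referenced with markdown image syntax inside `inputs`;
    # TGI fetches and preprocesses it before it reaches the Python model server.
    "inputs": f"![]({image_url})Where is the cow standing?\n",
    "parameters": {"max_new_tokens": 30, "do_sample": False},
}
resp = requests.post("http://127.0.0.1:3000/generate", json=payload, timeout=60)
print(resp.json()["generated_text"])
```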
124 lines · 4.4 KiB · Python
```python
import torch
import torch.distributed

from opentelemetry import trace
from typing import Optional, Tuple
from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLM,
    VlmCausalLMBatch,
    image_text_replacement,
    load_data_uri,
    split,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
    PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig, AutoImageProcessor

tracer = trace.get_tracer(__name__)


class PaliGemmaBatch(VlmCausalLMBatch):
    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            chunks = split(r.inputs)
            full_text = ""
            image_id = 0
            for chunk in chunks:
                if chunk["type"] == "text":
                    # PaliGemma prompt format: prefix text with <bos>, end with a newline.
                    full_text += "<bos>" + chunk["content"] + "\n"
                elif chunk["type"] == "image":
                    image = chunk["content"]
                    # Should never receive URLs anymore, processing should be done
                    # on the Rust layer.
                    # This avoids making n queries per TP
                    # if image.startswith("https://") or image.startswith("http://"):
                    #     image = processor.image_processor.fetch_images(image)
                    if image.startswith("data:"):
                        image = load_data_uri(image)
                    else:
                        raise RuntimeError(
                            "Cannot process input image not starting with data:"
                        )
                    # TODO do_convert_RGB should be on by default ?
                    image = image.convert("RGB")
                    image_input = processor.image_processor(image, return_tensors="pt")
                    # Replace the image chunk with the model's image placeholder tokens.
                    full_text += image_text_replacement(image_input, config, image_id)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        # <bos> is already inserted above, so no extra special tokens are added here.
        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=False,
        )["input_ids"]
        if image_inputs:
            # Stack the per-request image tensors into single batched tensors.
            image_input = image_inputs[0]
            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
            }
            if "pixel_attention_mask" in image_input:
                new_image_inputs["pixel_attention_mask"] = torch.cat(
                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
                )
            if "image_sizes" in image_input:
                new_image_inputs["image_sizes"] = torch.cat(
                    [img["image_sizes"] for img in image_inputs], dim=0
                )
            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs


class PaliGemma(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        super().__init__(
            config_cls=AutoConfig,
            model_cls=PaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self):
        return PaliGemmaBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)
```
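By the time a request reaches `PaliGemmaBatch.batch_tokenized_inputs`, image chunks are expected to be `data:` URIs (remote URLs are fetched on the Rust layer), otherwise a `RuntimeError` is raised. The sketch below shows the standard data-URI round trip; `decode_data_uri` is a hypothetical stand-in for what `load_data_uri` is assumed to do, and the file name is purely illustrative:

```python
import base64
from io import BytesIO

from PIL import Image


def decode_data_uri(uri: str) -> Image.Image:
    """Hypothetical stand-in for `load_data_uri`: decode a base64 data: URI to a PIL image."""
    if not uri.startswith("data:"):
        raise RuntimeError("Cannot process input image not starting with data:")
    encoded = uri.split(",", 1)[1]
    return Image.open(BytesIO(base64.b64decode(encoded)))


# Build a data: URI from a local file (illustrative path), much as the Rust layer
# would after fetching a remote image URL.
with open("cow_beach_1.png", "rb") as f:
    data_uri = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")

image = decode_data_uri(data_uri).convert("RGB")
```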