text-generation-inference/server/text_generation_server/models/vlm_causal_lm.py
drbh 40213c957f
Pali gemma modeling (#1895)
This PR adds paligemma modeling code

Blog post: https://huggingface.co/blog/paligemma
Transformers PR: https://github.com/huggingface/transformers/pull/30814

install the latest changes and run with
```bash
# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf

# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```


basic example sending various requests
```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")


images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]

prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is ther a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]

for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        inputs = f"![]({img}){prompt}\n"
        json_data = {
            "inputs": inputs,
            "parameters": {
                "max_new_tokens": 30,
                "do_sample": False,
            },
        }
        generated_output = client.text_generation(inputs, max_new_tokens=30, stream=False)
        print([f"{prompt}\n{generated_output}"])

```

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-05-16 06:58:47 +02:00

381 lines
14 KiB
Python

import re
import torch
import math
from PIL import Image
from io import BytesIO
import base64
from opentelemetry import trace
from typing import Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
FlashMistralBatch,
)
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.cache_manager import (
get_cache_manager,
)
# Module-level OpenTelemetry tracer; used below to wrap batch operations in spans.
tracer = trace.get_tracer(__name__)
# Matches Markdown image syntax `![alt](target "optional title")`.
# Group 1 captures the target between the parentheses (without the optional
# quoted title); group 2 captures the quoted title when present.
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
def split(string) -> List[Dict[str, str]]:
    """
    Break a prompt into an ordered list of text and image chunks.

    Every match of the module-level `IMAGES` pattern yields an entry
    `{"type": "image", "content": <first capture group>}`. The non-empty
    stretches of `string` before, between and after the matches yield
    `{"type": "text", "content": ...}` entries, in their original order.

    Args:
        string (`str`):
            The raw prompt, possibly containing Markdown image references.

    Returns:
        `List[Dict[str, str]]`: The chunks in the order they appear in `string`.
    """
    chunks: List[Dict[str, str]] = []
    consumed = 0
    for match in IMAGES.finditer(string):
        begin, finish = match.span()
        # Emit the plain text sitting between the previous match and this one.
        if begin != consumed:
            chunks.append({"type": "text", "content": string[consumed:begin]})
        chunks.append({"type": "image", "content": match.group(1)})
        consumed = finish
    # Whatever follows the last match (or the whole string if nothing matched).
    tail = string[consumed:]
    if tail:
        chunks.append({"type": "text", "content": tail})
    return chunks
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Compute how many `patch_size`-sized steps the best-fitting grid resolution spans per axis.

    Args:
        image_size (`tuple` or `list`):
            Size of the input image. It is forwarded unchanged to
            `select_best_resolution`; the caller in this file passes `[height, width]`.
        grid_pinpoints (`List`):
            Candidate resolutions, forwarded unchanged to `select_best_resolution`.
        patch_size (`int`):
            Divisor applied (floor division) to both selected dimensions.

    Returns:
        tuple: `(height // patch_size, width // patch_size)`, where `(height, width)`
        is the pair returned by `select_best_resolution`.

    Raises:
        ValueError: If `grid_pinpoints` is not a `list`.
    """
    if isinstance(grid_pinpoints, list):
        best_height, best_width = select_best_resolution(image_size, grid_pinpoints)
        return best_height // patch_size, best_width // patch_size
    raise ValueError("grid_pinpoints should be a list of tuples or lists")
def image_text_replacement(image_input, config, image_id) -> str:
    """
    Build the placeholder text that stands in for one image inside the prompt.

    Args:
        image_input:
            Output of the image processor for the image. Only read for
            `llava_next`, where `image_input["image_sizes"][image_id]` must
            unpack into `(height, width)`.
        config:
            Model configuration; `config.model_type` selects the strategy.
        image_id (`int`):
            Index of the image inside `image_input["image_sizes"]`.

    Returns:
        `str`: A run of `<image>` tokens (wrapped in
        `<fake_token_around_image>` markers for `idefics2`).

    Raises:
        RuntimeError: If `config.model_type` is not a supported multimodal type.
    """
    model_type = config.model_type

    if model_type == "idefics2":
        # TODO technically depends on image splitting which is not implemented.
        num_features = 320
        wrapper = "<fake_token_around_image>"
        return "".join((wrapper, "<image>" * num_features, wrapper))

    if model_type == "llava_next":
        height, width = image_input["image_sizes"][image_id]
        num_features = get_number_of_features(height, width, config)
        # Imported lazily: only this branch logs.
        from loguru import logger

        logger.info(f"Found {num_features} in image of resolution {height}x{width}")
        return "<image>" * num_features

    if model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens

    raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def get_unpadded_features(
    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
    """
    Count the image features left once the padding of the patch grid is cropped away.

    The full grid is `npatches * num_patch_height` rows by
    `npatches * num_patch_width` columns. The axis that was padded to fit the
    original aspect ratio is cropped symmetrically: `padding` rows (or columns)
    are removed from each side, so the surviving extent is
    `current - 2 * padding`. This is deliberately NOT the floor-scaled size:
    when `current - scaled` is odd the symmetric crop keeps one extra row or
    column, and the count must match it exactly (this is the behaviour of the
    LLaVA-NeXT `unpad_image` reference implementation).

    Args:
        height (`int`): Original image height.
        width (`int`): Original image width.
        npatches (`int`): Number of patches per side of one base tile.
        num_patch_height (`int`): Number of tiles along the height of the grid.
        num_patch_width (`int`): Number of tiles along the width of the grid.

    Returns:
        `Tuple[int, int]`: `(unpadded_features, newline_features)`, where
        `unpadded_features` is rows * columns after cropping and
        `newline_features` is the number of rows after cropping.
    """
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = width / height
    current_aspect_ratio: float = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        # Image is relatively wider than the grid: rows were padded.
        new_height = (height * current_width) // width
        padding = (current_height - new_height) // 2
        current_height = current_height - (2 * padding)
    else:
        # Image is relatively taller (or equal): columns were padded.
        new_width = (width * current_height) // height
        padding = (current_width - new_width) // 2
        current_width = current_width - (2 * padding)

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)
def get_number_of_features(height: int, width: int, config) -> int:
    """
    Compute the total number of image features produced for one image.

    The total is the sum of the unpadded grid features, one newline feature
    per remaining grid row, and the features of the base tile
    (`npatches ** 2`).

    Args:
        height (`int`): Original image height.
        width (`int`): Original image width.
        config:
            Model configuration providing `image_grid_pinpoints`,
            `vision_config.image_size` and `vision_config.patch_size`.

    Returns:
        `int`: The number of features for the image.

    Raises:
        AssertionError: If `image_size` is not a multiple of `patch_size`.
    """
    grid_pinpoints = config.image_grid_pinpoints
    base_resolution = config.vision_config.image_size
    patch = config.vision_config.patch_size

    assert base_resolution % patch == 0
    patches_per_side = base_resolution // patch

    # The base resolution (not the patch size) is the divisor here, so the
    # grid shape is expressed in whole base-resolution tiles.
    tiles_high, tiles_wide = get_anyres_image_grid_shape(
        [height, width], grid_pinpoints, base_resolution
    )
    unpadded, newlines = get_unpadded_features(
        height, width, patches_per_side, tiles_high, tiles_wide
    )

    # The base patch covers the entire image
    base = patches_per_side**2
    return unpadded + newlines + base
def load_data_uri(image_uri: str) -> Image.Image:
    """
    Decode the base64 payload of a data URI and open it as a PIL image.

    Only the text after the last comma is decoded; when the string contains no
    comma the whole string is treated as the payload.

    Args:
        image_uri (`str`): The data URI (or bare base64 string).

    Returns:
        `Image.Image`: The image opened from the decoded bytes.
    """
    payload = image_uri.rpartition(",")[2]
    return Image.open(BytesIO(base64.b64decode(payload)))
class VlmCausalLMBatch(FlashMistralBatch):
    """
    Batch type that additionally carries the image tensors of a multimodal request.

    The three image attributes are filled in by `from_pb_processor` and reset
    to `None` whenever batches are concatenated or filtered.
    """

    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        """Merge `batches` through the parent class and clear the image attributes."""
        merged = super(VlmCausalLMBatch, cls).concatenate(batches)
        merged.pixel_values = merged.pixel_attention_mask = merged.image_sizes = None
        return merged

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        """Filter through the parent class and clear the image attributes on the result."""
        kept = super().filter(request_ids)
        kept.pixel_values = kept.pixel_attention_mask = kept.image_sizes = None
        return kept

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
        """
        Expand image references into placeholder tokens and tokenize every request.

        Args:
            requests: Iterable of requests exposing `inputs` and `truncate`.
            tokenizer: Tokenizer called once on the list of expanded prompts.
            processor: Processor whose `image_processor` preprocesses each image.
            config: Model configuration (drives the placeholder expansion).

        Returns:
            tuple: `(input_ids, image_inputs)` where `image_inputs` is `None`
            when no request contained an image, otherwise a dict with
            `pixel_values` and, when the processor produced them,
            `pixel_attention_mask` and `image_sizes`, each concatenated along
            dim 0 over all images of the batch.

        Raises:
            RuntimeError: If an image is not a `data:` URI or a chunk has an
                unknown type.
        """
        prompts = []
        processed_images = []
        max_truncation = 0

        for request in requests:
            pieces = []
            # Each image is preprocessed on its own, so its entry is always
            # the first one of its own processor output.
            image_id = 0
            for chunk in split(request.inputs):
                chunk_type = chunk["type"]
                if chunk_type == "text":
                    pieces.append(chunk["content"])
                    continue
                if chunk_type != "image":
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")

                image = chunk["content"]
                # URLs should never reach this point: they are resolved in the
                # Rust layer, which avoids making n queries per TP.
                if not image.startswith("data:"):
                    raise RuntimeError(
                        "Cannot process input image not starting with data:"
                    )
                image = load_data_uri(image)
                image_input = processor.image_processor(image, return_tensors="pt")
                pieces.append(image_text_replacement(image_input, config, image_id))
                processed_images.append(image_input)

            prompts.append("".join(pieces))
            max_truncation = max(max_truncation, request.truncate)

        batch_tokenized_inputs = tokenizer(
            prompts,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=config.model_type != "paligemma",
        )["input_ids"]

        if not processed_images:
            return batch_tokenized_inputs, None

        # The first image decides which optional keys are present.
        first = processed_images[0]
        merged = {
            "pixel_values": torch.cat(
                [img["pixel_values"] for img in processed_images], dim=0
            ),
        }
        for key in ("pixel_attention_mask", "image_sizes"):
            if key in first:
                merged[key] = torch.cat([img[key] for img in processed_images], dim=0)
        return batch_tokenized_inputs, merged

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        """
        Build a batch from a protobuf batch, attaching the image tensors on `device`.

        Args:
            pb: Protobuf batch whose `requests` are tokenized.
            tokenizer: Tokenizer forwarded to `batch_tokenized_inputs` and `from_tokenized`.
            processor: Processor forwarded to `batch_tokenized_inputs`.
            config: Model configuration forwarded to `batch_tokenized_inputs`.
            dtype: Forwarded to `from_tokenized`.
            device: Device for `from_tokenized` and for the image tensors.

        Returns:
            `VlmCausalLMBatch`: The batch, with image attributes set to tensors
            on `device` or to `None` when absent.
        """
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

        # Default to "no image data", then overwrite with what is available.
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        if image_inputs is None:
            return batch

        batch.pixel_values = image_inputs["pixel_values"].to(device=device)
        if "pixel_attention_mask" in image_inputs:
            batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                device=device
            )
        if "image_sizes" in image_inputs:
            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
        return batch
class VlmCausalLM(BaseFlashMistral):
    """Causal LM whose forward pass also feeds the batch's image tensors to the model."""

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        """Batch class used by this model."""
        return VlmCausalLMBatch

    def forward(
        self, batch: VlmCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run the model on `batch`, eagerly or by replaying a captured CUDA graph.

        The eager path is taken during prefill (`cu_seqlen_prefill` is set) or
        when no captured graph exists for a batch size at least as large as
        the current one. Otherwise the inputs are copied into the graph's
        static buffers, the graph is replayed and its outputs are sliced back
        to the live batch size.

        Args:
            batch: The batch to run.

        Returns:
            tuple: `(logits, speculative_logits)`; the second element may be `None`.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            # Every sequence is expanded into `new_length` rows: the current
            # token followed by its speculative tokens.
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1

            offsets = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            offsets_int32 = offsets.to(dtype=torch.int32)

            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + offsets
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + offsets_int32).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + offsets_int32
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        # Try to find an associated cuda graph: the smallest captured batch
        # size that can hold the current one.
        bs = input_ids.shape[0]
        padded_bs = min((k for k in self.cuda_graphs if k >= bs), default=None)
        cuda_graph = self.cuda_graphs[padded_bs] if padded_bs is not None else None

        replay_graph = cu_seqlen_prefill is None and cuda_graph is not None
        if not replay_graph:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
                pixel_attention_mask=batch.pixel_attention_mask,
                image_sizes=batch.image_sizes,
            )
            # Cleared after the call so later forward passes on this batch do
            # not receive them again.
            batch.prefill_cache_indices = None
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        # Reset first so rows beyond the live batch hold -1 / 0.
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        graph_speculative_logits = cuda_graph["speculative_logits"]
        if graph_speculative_logits is not None:
            speculative_logits = graph_speculative_logits[:bs]
        else:
            speculative_logits = None
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits