mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet though.

Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.

Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.

Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one has reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same people — sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @

@OlivierDehaene OR @Narsil

-->
374 lines
14 KiB
Python
374 lines
14 KiB
Python
import re
|
|
import torch
|
|
import math
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
import base64
|
|
|
|
from opentelemetry import trace
|
|
from typing import Optional, Tuple, List, Type, Dict
|
|
|
|
from transformers import PreTrainedTokenizerBase
|
|
from transformers.image_processing_utils import select_best_resolution
|
|
from text_generation_server.pb import generate_pb2
|
|
from text_generation_server.models.flash_mistral import (
|
|
BaseFlashMistral,
|
|
FlashMistralBatch,
|
|
)
|
|
from text_generation_server.models.cache_manager import (
|
|
get_cache_manager,
|
|
)
|
|
|
|
# Module-level OpenTelemetry tracer; used by the `@tracer.start_as_current_span`
# decorators on the batch methods in this module.
tracer = trace.get_tracer(__name__)
|
|
|
|
# Markdown image syntax: ``![alt](target)`` or ``![alt](target "title")``.
# Group 1 captures the target (e.g. a ``data:`` URI); group 2 the optional title.
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


def split(string) -> List[Dict[str, str]]:
    """Break ``string`` into an ordered list of text and image chunks.

    Every markdown image matched by ``IMAGES`` becomes
    ``{"type": "image", "content": <target>}``. The text before, between and
    after images becomes ``{"type": "text", "content": <text>}``. Empty text
    segments are omitted, so an empty string yields ``[]``.
    """
    chunks: List[Dict[str, str]] = []
    consumed = 0
    for match in IMAGES.finditer(string):
        # Matches are non-overlapping and in order, so this slice is the
        # (possibly empty) text since the previous image.
        leading_text = string[consumed : match.start()]
        if leading_text:
            chunks.append({"type": "text", "content": leading_text})
        chunks.append({"type": "image", "content": match.group(1)})
        consumed = match.end()

    trailing_text = string[consumed:]
    if trailing_text:
        chunks.append({"type": "text", "content": trailing_text})
    return chunks
|
|
|
|
|
|
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the tile grid used to preprocess an image of any resolution.

    The best-fitting resolution from ``grid_pinpoints`` is chosen with
    ``select_best_resolution`` and then cut into square tiles of side
    ``patch_size``.

    Args:
        image_size (`tuple` or `list`):
            The size of the input image in the format `(height, width)`.
        grid_pinpoints (`list` or `tuple`):
            The possible resolutions. Each item should be a tuple or list of
            the form `(height, width)`.
        patch_size (`int`):
            The side length of one tile. Despite the name this is a whole tile,
            not a ViT patch: `get_number_of_features` passes the vision
            tower's `image_size` here.

    Returns:
        tuple: The grid shape in the format `(num_tiles_height, num_tiles_width)`.

    Raises:
        ValueError: If `grid_pinpoints` is not a list or tuple.
    """
    # Tuples are accepted too: select_best_resolution only iterates the pinpoints.
    if not isinstance(grid_pinpoints, (list, tuple)):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size
|
|
|
|
|
|
def image_text_replacement(image_input, config, image_id) -> str:
    """Return the placeholder text that stands in for one image in the prompt.

    Args:
        image_input: Image-processor output for the image. For ``llava_next``
            its ``"image_sizes"`` entry is read at row ``image_id``.
        config: Model config. ``config.model_type`` selects the strategy.
        image_id: Row of ``image_input["image_sizes"]`` that describes this image.

    Returns:
        A string of image placeholder tokens whose count matches the number of
        features the model produces for the image.

    Raises:
        RuntimeError: If ``config.model_type`` is not a supported multimodal model.
    """
    if config.model_type == "idefics2":
        # TODO technically depends on image splitting which is not implemented.
        num_features = 320
        return (
            "<fake_token_around_image>"
            + "<image>" * num_features
            + "<fake_token_around_image>"
        )
    elif config.model_type == "llava_next":
        from loguru import logger

        # The processor may return ``image_sizes`` as a tensor (it is called
        # with return_tensors="pt"). Convert to Python ints so the feature
        # arithmetic and the log line operate on plain integers.
        height, width = (int(dim) for dim in image_input["image_sizes"][image_id])
        num_features = get_number_of_features(height, width, config)
        logger.info(
            f"Found {num_features} features in image of resolution {height}x{width}"
        )
        return "<image>" * num_features
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
|
|
|
|
|
def get_unpadded_features(
    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
    """Count the image features left after LLaVA-NeXT's unpadding step.

    The tile grid spans ``npatches * num_patch_height`` by
    ``npatches * num_patch_width`` feature positions. The image's aspect ratio
    rarely matches the grid's, so padded rows (or columns) are stripped
    symmetrically. The padding per side is floored, so the kept extent is
    ``current - 2 * padding``. That can be one larger than the rescaled image
    size when the total padding is odd. This mirrors the slicing done by
    ``unpad_image`` in transformers' LLaVA-NeXT model; keep the two in sync, or
    the number of ``<image>`` tokens will not match the number of features.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        npatches: Feature positions along one side of a single tile.
        num_patch_height: Number of tiles stacked vertically.
        num_patch_width: Number of tiles side by side.

    Returns:
        ``(unpadded_features, newline_features)``: the number of kept feature
        positions, and the number of kept rows.
    """
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = width / height
    current_aspect_ratio: float = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        # Image is relatively wider than the grid: padding sits above and below.
        new_height = (height * current_width) // width
        padding = (current_height - new_height) // 2
        current_height = current_height - 2 * padding
    else:
        # Image is relatively taller (or equal): padding sits left and right.
        new_width = (width * current_height) // height
        padding = (current_width - new_width) // 2
        current_width = current_width - 2 * padding

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)
|
|
|
|
|
|
def get_number_of_features(height: int, width: int, config) -> int:
    """Return how many ``<image>`` placeholder tokens a LLaVA-NeXT image needs.

    The total is the unpadded features of the any-resolution tile grid, plus
    the grid's newline features (both from ``get_unpadded_features``), plus the
    features of the base patch, which covers the entire image.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        config: Model config exposing ``image_grid_pinpoints`` and
            ``vision_config.image_size`` / ``vision_config.patch_size``.

    Raises:
        ValueError: If the vision tower's ``image_size`` is not a multiple of
            its ``patch_size``.
    """
    image_grid_pinpoints = config.image_grid_pinpoints
    image_size = config.vision_config.image_size
    patch_size = config.vision_config.patch_size

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if image_size % patch_size != 0:
        raise ValueError(
            f"vision_config.image_size ({image_size}) must be a multiple of "
            f"vision_config.patch_size ({patch_size})"
        )

    # Feature positions along one side of a single (square) vision-tower tile.
    npatches = image_size // patch_size

    # The grid's "patch size" is a whole tile, i.e. the vision tower's image_size.
    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        [height, width],
        image_grid_pinpoints,
        image_size,
    )
    unpadded_features, newline_features = get_unpadded_features(
        height, width, npatches, num_patch_height, num_patch_width
    )
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features
|
|
|
|
|
|
def load_data_uri(image_uri: str) -> Image.Image:
    """Decode the base64 payload of a ``data:`` URI and open it as a PIL image.

    Only the text after the last comma is decoded; any header before it
    (e.g. ``data:image/png;base64``) is discarded. If there is no comma, the
    whole string is treated as the payload.
    """
    _, _, encoded = image_uri.rpartition(",")
    return Image.open(BytesIO(base64.b64decode(encoded)))
|
|
|
|
|
|
class VlmCausalLMBatch(FlashMistralBatch):
    """Flash batch that additionally carries the image tensors of its requests.

    The image tensors are filled in by ``from_pb_processor`` and reset to
    ``None`` by ``concatenate`` and ``filter``.
    """

    # Image-processor outputs concatenated along dim 0 across every image in the
    # batch (see ``batch_tokenized_inputs``), or None when there are none.
    pixel_values: Optional[torch.Tensor]
    # Only set when the image processor emits a ``pixel_attention_mask``.
    pixel_attention_mask: Optional[torch.Tensor]
    # Only set when the image processor emits ``image_sizes``.
    image_sizes: Optional[torch.Tensor]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        """Concatenate ``batches``; image tensors are not carried over."""
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        """Keep only ``request_ids``; image tensors are dropped."""
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch

    @staticmethod
    def _load_image(image: str):
        """Decode an inline ``data:`` image; any other reference is rejected."""
        # URLs should never reach this point: fetching is done on the Rust
        # layer, which avoids making n queries per TP shard.
        if not image.startswith("data:"):
            raise RuntimeError("Cannot process input image not starting with data:")
        return load_data_uri(image)

    @staticmethod
    def _collate_image_inputs(image_inputs) -> Optional[Dict[str, torch.Tensor]]:
        """Concatenate per-image processor outputs along dim 0; None if empty."""
        if not image_inputs:
            return None
        first = image_inputs[0]
        collated = {
            "pixel_values": torch.cat(
                [img["pixel_values"] for img in image_inputs], dim=0
            ),
        }
        # Optional keys are included when the first image's output has them.
        for key in ("pixel_attention_mask", "image_sizes"):
            if key in first:
                collated[key] = torch.cat([img[key] for img in image_inputs], dim=0)
        return collated

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
        """Tokenize each request, expanding inline images into placeholder tokens.

        Returns:
            ``(input_ids, image_inputs)``: ``input_ids`` is the tokenizer's
            ``"input_ids"`` output for all requests, and ``image_inputs`` is a
            dict of concatenated image tensors, or ``None`` when no request
            contains an image.

        Raises:
            RuntimeError: On a non-``data:`` image or an unknown chunk type.
        """
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            full_text = ""
            for chunk in split(r.inputs):
                if chunk["type"] == "text":
                    full_text += chunk["content"]
                elif chunk["type"] == "image":
                    image = cls._load_image(chunk["content"])
                    image_input = processor.image_processor(image, return_tensors="pt")
                    # Each image is run through the processor on its own, so
                    # its outputs describe exactly one image: row 0.
                    full_text += image_text_replacement(image_input, config, 0)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs, cls._collate_image_inputs(image_inputs)

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        """Build a batch from ``pb``, moving any image tensors onto ``device``."""
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

        if image_inputs is None:
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
            return batch

        batch.pixel_values = image_inputs["pixel_values"].to(device=device)
        mask = image_inputs.get("pixel_attention_mask")
        batch.pixel_attention_mask = mask.to(device=device) if mask is not None else None
        sizes = image_inputs.get("image_sizes")
        batch.image_sizes = sizes.to(device=device) if sizes is not None else None
        return batch
|
|
|
|
|
|
class VlmCausalLM(BaseFlashMistral):
    """Flash causal LM that forwards the batch's vision inputs to the model."""

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return VlmCausalLMBatch

    def forward(
        self, batch: VlmCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one model step on ``batch``.

        Prefill steps, or any step without a suitable captured CUDA graph, call
        the model eagerly with the batch's image tensors. Those tensors (and
        ``prefill_cache_indices``) are then cleared on the batch. Otherwise the
        smallest captured CUDA graph whose batch size is at least the current
        one is replayed.

        Returns:
            ``(logits, speculative_logits)``. For CUDA-graph replays
            ``speculative_logits`` is ``None`` when the graph captured none.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            # Score the current token plus every speculated token in a single
            # pass: each sequence is expanded into `new_length` positions.
            num_seqs, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            position_ids = (
                position_ids.unsqueeze(-1).expand(num_seqs, new_length) + arange
            ).view(-1)
            slots = (
                slots.unsqueeze(-1).expand(num_seqs, new_length) + arange_int
            ).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(num_seqs, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(num_seqs, new_length, -1)
                .reshape(num_seqs * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        # Try to find an associated cuda graph: the smallest captured batch size
        # that can hold this batch (its static inputs are padded).
        padded_bs = min((k for k in self.cuda_graphs.keys() if k >= bs), default=None)
        cuda_graph = self.cuda_graphs[padded_bs] if padded_bs is not None else None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
                pixel_attention_mask=batch.pixel_attention_mask,
                image_sizes=batch.image_sizes,
            )
            # Clear the one-shot inputs so later steps do not pass them again.
            batch.prefill_cache_indices = None
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        # Padded slots are -1 and padded lengths are 0.
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
|