mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
The router now sends the input as chunks in addition to a single string. This change makes the server process the chunked input instead of the string, which also lets us remove the image-extraction code from the server.
339 lines
13 KiB
Python
339 lines
13 KiB
Python
import torch
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
|
|
from opentelemetry import trace
|
|
from typing import Iterable, Optional, Tuple, List, Type, Dict
|
|
|
|
from transformers import PreTrainedTokenizerBase
|
|
from transformers.image_processing_utils import select_best_resolution
|
|
from text_generation_server.pb import generate_pb2
|
|
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
|
|
from text_generation_server.models.flash_mistral import (
|
|
BaseFlashMistral,
|
|
)
|
|
|
|
# Module-level OpenTelemetry tracer; the batch methods below use it to open
# "concatenate" / "filter" spans.
tracer = trace.get_tracer(__name__)
|
|
|
|
|
|
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple` or `list`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size (in pixels) of each grid cell; the chosen resolution is divided by it.

    Returns:
        tuple: The shape of the image patch grid in the format (height, width), i.e.
        (number of cell rows, number of cell columns).

    Raises:
        ValueError: If `grid_pinpoints` is not a list.
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    # `select_best_resolution` takes the original size as (height, width) and
    # returns the best-fitting pinpoint in the same (height, width) order.
    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size
|
|
|
|
|
|
def image_text_replacement(image_input, config, image_id) -> str:
    """Return the placeholder text that stands in for one image in the prompt.

    The placeholder depends on ``config.model_type``:

    - ``idefics2``: 320 ``<image>`` tokens wrapped in ``<fake_token_around_image>``.
    - ``llava_next``: one ``<image>`` token per visual feature, computed from the
      resolution stored in ``image_input["image_sizes"][image_id]``.
    - ``paligemma``: ``config.text_config.num_image_tokens`` ``<image>`` tokens.

    Args:
        image_input: Image-processor output for the image (only read for
            ``llava_next``).
        config: Model config; ``model_type`` selects the branch.
        image_id: Row of ``image_input["image_sizes"]`` to read (``llava_next`` only).

    Raises:
        RuntimeError: If ``config.model_type`` is not one of the above.
    """
    if config.model_type == "idefics2":
        # TODO technically depends on image splitting which is not implemented.
        num_features = 320
        return (
            "<fake_token_around_image>"
            + "<image>" * num_features
            + "<fake_token_around_image>"
        )
    elif config.model_type == "llava_next":
        from loguru import logger

        height, width = image_input["image_sizes"][image_id]
        # The row entries may be 0-d tensors; convert to plain ints so the
        # feature-count arithmetic runs on Python numbers and the log is readable.
        height, width = int(height), int(width)
        num_features = get_number_of_features(height, width, config)
        logger.info(
            f"Found {num_features} features in image of resolution {height}x{width}"
        )
        return "<image>" * num_features
    elif config.model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
|
|
|
|
|
def get_unpadded_features(
    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
    """Count the tile-grid features that remain after removing padding.

    The tiled view is a grid of ``npatches * num_patch_height`` by
    ``npatches * num_patch_width`` features. The original ``height x width``
    image was fit into it with its aspect ratio preserved, so one axis carries
    padding. That padding is removed symmetrically: ``padding`` rows (or
    columns) are dropped from *each* side, where ``padding`` is half the slack
    rounded down. This is intended to match the model-side unpadding (the
    LLaVA-NeXT ``unpad_image`` scheme). When the slack is odd, one more row or
    column survives than the exact rescaled size.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        npatches: Features per side of a single tile.
        num_patch_height: Number of tiles stacked vertically.
        num_patch_width: Number of tiles placed horizontally.

    Returns:
        ``(unpadded_features, newline_features)``, where ``newline_features``
        is one per remaining feature row.
    """
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = width / height
    current_aspect_ratio: float = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        # Image is relatively wider than the grid: the padding is vertical.
        new_height = (height * current_width) // width
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        # Image is relatively taller (or equal): the padding is horizontal.
        new_width = (width * current_height) // height
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)
|
|
|
|
|
|
def get_number_of_features(height: int, width: int, config) -> int:
    """Return how many ``<image>`` tokens an image of ``height x width`` expands to.

    The count is the sum of:

    - the unpadded features of the high-resolution tile grid,
    - one newline feature per remaining feature row, and
    - the features of the base view that covers the whole image (``npatches ** 2``).

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        config: Model config providing ``image_grid_pinpoints`` and
            ``vision_config.image_size`` / ``vision_config.patch_size``.

    Raises:
        ValueError: If ``vision_config.image_size`` is not a multiple of
            ``vision_config.patch_size``.
    """
    image_grid_pinpoints = config.image_grid_pinpoints
    image_size = config.vision_config.image_size
    patch_size = config.vision_config.patch_size

    # Validate explicitly: an `assert` would be stripped under `python -O`.
    if image_size % patch_size != 0:
        raise ValueError(
            f"vision_config.image_size ({image_size}) must be a multiple of "
            f"vision_config.patch_size ({patch_size})"
        )

    # Features per side of one vision-tower tile.
    npatches = image_size // patch_size

    # Each grid cell is one full tile, so the grid is measured in `image_size` units.
    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        [height, width],
        image_grid_pinpoints,
        image_size,
    )
    unpadded_features, newline_features = get_unpadded_features(
        height, width, npatches, num_patch_height, num_patch_width
    )
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features
|
|
|
|
|
|
class VlmCausalLMBatch(FlashCausalLMBatch):
    """``FlashCausalLMBatch`` that additionally carries image inputs.

    Image tensors are set by ``from_pb_processor``. They are not carried over
    when batches are concatenated or filtered.
    """

    # All images of the batch concatenated along dim 0, or None.
    pixel_values: Optional[torch.Tensor]
    # Concatenated `pixel_attention_mask`, when the image processor produces one; else None.
    pixel_attention_mask: Optional[torch.Tensor]
    # Concatenated `image_sizes`, when the image processor produces them; else None.
    image_sizes: Optional[torch.Tensor]

    @staticmethod
    def _clear_image_inputs(batch):
        """Set every image attribute of ``batch`` to None and return ``batch``."""
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        """Concatenate ``batches``; image inputs are dropped from the result."""
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
        return cls._clear_image_inputs(batch)

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        """Keep only ``request_ids``; image inputs are dropped from the result."""
        batch = super().filter(request_ids)
        return self._clear_image_inputs(batch)

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
    ):
        """Build each request's prompt from its input chunks and tokenize the batch.

        Text chunks are appended verbatim. Each image chunk is decoded with PIL,
        run through ``processor.image_processor`` and replaced in the prompt by
        the placeholder text from ``image_text_replacement``. All prompts are
        truncated to the largest ``truncate`` value of any request.

        Returns:
            ``(input_ids, image_inputs)``: the tokenizer's ``input_ids`` for every
            request, and either None (no images in the batch) or a dict with the
            processor outputs of all images concatenated along dim 0 under
            ``"pixel_values"`` and, when the processor produces them,
            ``"pixel_attention_mask"`` / ``"image_sizes"``.

        Raises:
            RuntimeError: On an unknown chunk type.
        """
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            text_parts = []
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    text_parts.append(chunk.text)
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    image_input = processor.image_processor(image, return_tensors="pt")
                    # The processor is called with a single image, so its
                    # per-image metadata is always read at index 0.
                    text_parts.append(image_text_replacement(image_input, config, 0))
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

            batch_inputs.append("".join(text_parts))
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            # Special tokens are not added for PaliGemma prompts.
            add_special_tokens=config.model_type != "paligemma",
        )["input_ids"]

        if not image_inputs:
            return batch_tokenized_inputs, None

        merged_image_inputs = {
            "pixel_values": torch.cat(
                [img["pixel_values"] for img in image_inputs], dim=0
            ),
        }
        # Optional outputs: presence is decided by the first image's processor output.
        for key in ("pixel_attention_mask", "image_sizes"):
            if key in image_inputs[0]:
                merged_image_inputs[key] = torch.cat(
                    [img[key] for img in image_inputs], dim=0
                )
        return batch_tokenized_inputs, merged_image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        """Create a batch from ``pb``, moving any image inputs to ``device``.

        Image attributes that the processor did not produce are left as None.
        """
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        cls._clear_image_inputs(batch)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            if "pixel_attention_mask" in image_inputs:
                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                    device=device
                )
            if "image_sizes" in image_inputs:
                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
        return batch
|
|
|
|
|
|
class VlmCausalLM(BaseFlashMistral):
    """Flash causal LM for vision-language models.

    Prefill forwards additionally pass the batch's image inputs
    (``pixel_values``, ``pixel_attention_mask``, ``image_sizes``) to the model.
    Those inputs are cleared from the batch after the eager forward pass.
    """

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return VlmCausalLMBatch

    def forward(
        self, batch: VlmCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one model step over ``batch``.

        Two cases run the model eagerly, with the image inputs: prefill steps
        (``cu_seqlen_prefill`` set) and steps with no large-enough captured CUDA
        graph. Every other step copies its inputs into the smallest captured
        CUDA graph whose batch size fits, then replays it.

        Returns:
            ``(logits, speculative_logits)``; ``speculative_logits`` may be None.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            # Expand each sequence into `speculative_length + 1` rows: the current
            # token followed by its speculative ids, each one position/slot/length
            # further along.
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        # Try to find an associated cuda graph: the smallest one that fits `bs`.
        padded_bs = min((k for k in self.cuda_graphs.keys() if k >= bs), default=None)
        cuda_graph = self.cuda_graphs[padded_bs] if padded_bs is not None else None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
                pixel_attention_mask=batch.pixel_attention_mask,
                image_sizes=batch.image_sizes,
            )
            # These inputs are consumed by this eager pass; clear them so later
            # steps do not pass them again.
            batch.prefill_cache_indices = None
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
|