import re
import torch
import math
from PIL import Image
from io import BytesIO
import base64

from opentelemetry import trace
from typing import Optional, Tuple, List, Type, Dict

from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
    FlashMistralBatch,
)
from text_generation_server.models.cache_manager import (
    get_cache_manager,
)

tracer = trace.get_tracer(__name__)

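# Matches markdown image syntax ![alt](url "optional title"); group(1) captures the URL or data URI.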
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


def split(string) -> List[Dict[str, str]]:
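    # A minimal sketch of the expected behaviour (illustrative values only):
    #   split("hi ![](data:image/png;base64,xxxx) bye")
    #   -> [{"type": "text", "content": "hi "},
    #       {"type": "image", "content": "data:image/png;base64,xxxx"},
    #       {"type": "text", "content": " bye"}]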
    parts = []
    cursor = 0
    for pattern in IMAGES.finditer(string):
        start = pattern.start()
        if start != cursor:
            parts.append({"type": "text", "content": string[cursor:start]})

        parts.append({"type": "image", "content": pattern.group(1)})
        cursor = pattern.end()

    if cursor != len(string):
        parts.append({"type": "text", "content": string[cursor:]})

    return parts


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (height, width).
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size


def get_number_of_features(height: int, width: int, config) -> int:
    # From config
    # Hardcoded for CLIP for now
    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
    image_grid_pinpoints = config.image_grid_pinpoints
    image_size = config.vision_config.image_size
    patch_size = config.vision_config.patch_size

    assert image_size % patch_size == 0

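    # Number of ViT patches along one side of a single vision tile
    # (e.g. 336 // 14 = 24 for a CLIP ViT-L/14-336 tower, as assumed above).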
    npatches = image_size // patch_size

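    # Grid of tiles selected by the anyres preprocessing. Note: image_size is passed
    # as the `patch_size` argument, so this counts whole tiles, not ViT patches.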
    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        [height, width],
        image_grid_pinpoints,
        image_size,
    )

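    # Height of the image measured in patches, assuming its width is scaled to span
    # exactly npatches patches (aspect ratio preserved, rounded up).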
    height_of_patch = math.ceil(height / width * npatches)

    unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
    # Newline features are only added after the width dimension
    newline_features = height_of_patch * num_patch_width
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features


def load_data_uri(image_uri: str) -> Image.Image:
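    # Drop the "data:<mime>;base64," prefix and decode the base64 payload into a PIL image.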
    image_uri = image_uri.split(",")[-1]
    content = base64.b64decode(image_uri)
    image = Image.open(BytesIO(content))
    return image


# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
# assert get_number_of_features(640, 640) == 2928


class VlmCausalLMBatch(FlashMistralBatch):
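    # Flash Mistral batch extended with the vision inputs needed for multimodal prefill;
    # both fields are cleared again once the images have been consumed.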
    pixel_values: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
        batch.image_sizes = None
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.image_sizes = None
        return batch

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
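        # Expand each request into a single prompt string, replacing every markdown
        # image with one "<image>" placeholder per visual feature, and collect the
        # processed image tensors alongside it.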
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            chunks = split(r.inputs)
            full_text = ""
            for chunk in chunks:
                if chunk["type"] == "text":
                    full_text += chunk["content"]
                elif chunk["type"] == "image":
                    image = chunk["content"]
                    # We should never receive URLs here anymore; fetching is done
                    # on the Rust layer.
                    # This avoids making n queries per TP rank.
                    # if image.startswith("https://") or image.startswith("http://"):
                    #     image = processor.image_processor.fetch_images(image)
                    if image.startswith("data:"):
                        image = load_data_uri(image)
                    else:
                        raise RuntimeError(
                            "Cannot process input image not starting with data:"
                        )
                    image_input = processor.image_processor(image, return_tensors="pt")
                    height, width = image_input["image_sizes"][0]
                    num_features = get_number_of_features(height, width, config)
                    full_text += "<image>" * num_features
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        if image_inputs:
            image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
            }
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
        else:
            batch.pixel_values = None
            batch.image_sizes = None
        return batch


class VlmCausalLM(BaseFlashMistral):
    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return VlmCausalLMBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.language_model.model.layers),
            model.language_model.model.num_key_value_heads,
            model.language_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.language_model, "max_past", None)

    def forward(
        self, batch: VlmCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
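        # With speculative decoding, each sequence carries `speculative_length` candidate
        # tokens on top of its last token, so every per-token tensor is expanded to
        # `speculative_length + 1` entries per sequence before the forward pass.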
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

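        # Round the batch size up to a padded bucket and look for a CUDA graph captured
        # for that size; if none exists, fall back to the eager forward below.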
        bs = input_ids.shape[0]
        padded_bs = bs
        if bs == 3:
            padded_bs = 4
        elif 3 < bs <= 8:
            padded_bs = 8
        elif bs > 8:
            padded_bs = (bs + 7) // 8 * 8

        # Try to find an associated cuda graph
        cuda_graph = self.cuda_graphs.get(padded_bs, None)

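        # Prefill (and any batch size without a captured graph) runs eagerly; pixel values
        # and image sizes are only needed on this path and are cleared once consumed.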
        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
                image_sizes=batch.image_sizes,
            )
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            if batch.pixel_values is not None:
                batch.pixel_values = None
            if batch.image_sizes is not None:
                batch.image_sizes = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits