import re
import torch
import math
from PIL import Image
from io import BytesIO
import base64

from opentelemetry import trace
from typing import Optional, Tuple, List, Type, Dict

from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
    FlashMistralBatch,
)
from text_generation_server.models.cache_manager import (
    get_cache_manager,
)

tracer = trace.get_tracer(__name__)

# Markdown image syntax: ![alt](target "optional title").
# Group 1 captures the target (URL / data URI), group 2 the optional title.
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


def split(string: str) -> List[Dict[str, str]]:
    """Split a prompt into an ordered list of text and image chunks.

    Every markdown image match of ``IMAGES`` becomes an
    ``{"type": "image", "content": <target>}`` chunk; the text between
    matches becomes ``{"type": "text", "content": <text>}`` chunks. Empty
    text segments are not emitted.
    """
    parts = []
    cursor = 0
    for pattern in IMAGES.finditer(string):
        start = pattern.start()
        if start != cursor:
            parts.append({"type": "text", "content": string[cursor:start]})

        parts.append({"type": "image", "content": pattern.group(1)})
        cursor = pattern.end()

    # Trailing text after the last image (or the whole string if no image).
    if cursor != len(string):
        parts.append({"type": "text", "content": string[cursor:]})

    return parts


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image. It is forwarded unchanged to
            `select_best_resolution`; the caller in this file passes `[height, width]`.
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The divisor applied to the selected resolution. Note that the caller in this file
            passes the vision tower's `image_size` here, not its `patch_size`.

    Returns:
        tuple: The shape of the image patch grid in the format (height, width).

    Raises:
        ValueError: If `grid_pinpoints` is not a list.
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size


def get_number_of_features(height: int, width: int, config) -> int:
    """Return the number of image placeholder features for one image.

    The count is the sum of the unpadded high-resolution grid features, one
    "newline" feature per row of that grid, and the features of the base
    (whole image) patch.

    Args:
        height: Image height.
        width: Image width.
        config: Model config exposing `image_grid_pinpoints`,
            `vision_config.image_size` and `vision_config.patch_size`.
    """
    # From config
    # Hardcoded for CLIP for now
    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
    image_grid_pinpoints = config.image_grid_pinpoints
    image_size = config.vision_config.image_size
    patch_size = config.vision_config.patch_size

    assert image_size % patch_size == 0

    npatches = image_size // patch_size

    # The grid is expressed in units of full vision-tower tiles, hence
    # `image_size` (not `patch_size`) is used as the divisor.
    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        [height, width],
        image_grid_pinpoints,
        image_size,
    )

    height_of_patch = math.ceil(height / width * npatches)

    unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
    # They are only added after width
    newline_features = height_of_patch * num_patch_width
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features


def load_data_uri(image_uri: str) -> Image.Image:
    """Decode the base64 payload of a data URI and open it as a PIL image.

    Everything up to and including the last comma is treated as the header
    and discarded.
    """
    image_uri = image_uri.split(",")[-1]
    content = base64.b64decode(image_uri)
    image = Image.open(BytesIO(content))
    return image


# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
# assert get_number_of_features(640, 640) == 2928


class VlmCausalLMBatch(FlashMistralBatch):
    """`FlashMistralBatch` that additionally carries image inputs."""

    # Image tensors for the pending prefill; reset to None on
    # concatenate/filter and after the eager forward pass in `VlmCausalLM`.
    pixel_values: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        """Concatenate batches via the parent class and drop image inputs."""
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
        batch.image_sizes = None
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        """Filter the batch via the parent class and drop image inputs."""
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.image_sizes = None
        return batch

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
        """Tokenize the requests' prompts and preprocess their inline images.

        Each markdown image in a prompt must be a `data:` URI. It is run
        through `processor.image_processor` and replaced in the text by as
        many image placeholder tokens as `get_number_of_features` reports.

        Returns:
            A tuple `(input_ids, image_inputs)` where `image_inputs` is
            `None` when no request contains an image, otherwise a dict with
            the concatenated `"pixel_values"` and `"image_sizes"`.

        Raises:
            RuntimeError: If an image is not a `data:` URI or a chunk has an
                unknown type.
        """
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            text_pieces = []
            for chunk in split(r.inputs):
                if chunk["type"] == "text":
                    text_pieces.append(chunk["content"])
                elif chunk["type"] == "image":
                    image = chunk["content"]
                    # Should never receive URLs anymore, processing should be done
                    # on the rust layer.
                    # This avoids making n queries per TP
                    # if image.startswith("https://") or image.startswith("http://"):
                    #     image = processor.image_processor.fetch_images(image)
                    if image.startswith("data:"):
                        image = load_data_uri(image)
                    else:
                        raise RuntimeError(
                            "Cannot process input image not starting with data:"
                        )
                    image_input = processor.image_processor(image, return_tensors="pt")
                    height, width = image_input["image_sizes"][0]
                    num_features = get_number_of_features(height, width, config)
                    # One placeholder token per image feature. The placeholder
                    # must be non-empty, otherwise the repetition is a no-op and
                    # the prompt carries no slots for the image features.
                    # NOTE(review): "<image>" is the LLaVA-NeXT image token —
                    # confirm it matches the tokenizer's image token.
                    text_pieces.append("<image>" * num_features)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")

            batch_inputs.append("".join(text_pieces))
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        if image_inputs:
            image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
            }
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        """Build a batch from a protobuf batch, including image inputs.

        The text is tokenized with `batch_tokenized_inputs`, the batch is
        created with `from_tokenized`, and the image tensors (if any) are
        moved to `device` and attached to the batch.
        """
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
        else:
            batch.pixel_values = None
            batch.image_sizes = None
        return batch


class VlmCausalLM(BaseFlashMistral):
    """Causal LM with image inputs built on top of `BaseFlashMistral`."""

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        """Batch class used by this model."""
        return VlmCausalLMBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        """Return `(num_layers, num_key_value_heads, head_size)` of the language model."""
        return (
            len(model.language_model.model.layers),
            model.language_model.model.num_key_value_heads,
            model.language_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        """Return the language model's `max_past` attribute, or `None` if absent."""
        return getattr(self.model.language_model, "max_past", None)

    def forward(
        self, batch: VlmCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one forward pass for `batch`.

        For prefill, or when no CUDA graph is registered for the (padded)
        batch size, the model is called directly and the batch's
        `prefill_cache_indices`, `pixel_values` and `image_sizes` are reset
        to `None` afterwards. Otherwise the inputs are copied into the
        static buffers of the matching CUDA graph and the graph is replayed.

        Returns:
            A tuple `(logits, speculative_logits)`; the second element may be
            `None`.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        if batch.speculative_ids is not None:
            # Expand every sequence into (1 + speculative_length) entries:
            # the current token followed by its speculated continuation.
            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        # Round the batch size up to the sizes used as keys of `self.cuda_graphs`:
        # 1, 2, 4, 8, then multiples of 8.
        bs = input_ids.shape[0]
        padded_bs = bs
        if bs == 3:
            padded_bs = 4
        elif 3 < bs <= 8:
            padded_bs = 8
        elif bs > 8:
            padded_bs = (bs + 7) // 8 * 8

        # Try to find an associated cuda graph
        cuda_graph = self.cuda_graphs.get(padded_bs, None)

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
                image_sizes=batch.image_sizes,
            )
            # These inputs have been consumed by this call; clear them on the batch.
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            if batch.pixel_values is not None:
                batch.pixel_values = None
            if batch.image_sizes is not None:
                batch.image_sizes = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits