import torch

from io import BytesIO
from typing import Iterable, Optional, Tuple, List, Type, Dict

from PIL import Image
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution

from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
)

tracer = trace.get_tracer(__name__)


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple` or `list`):
            The size of the input image in the format (height, width). It is
            forwarded unchanged to `select_best_resolution`.
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            Side length of the patches the selected resolution is divided into
            (the same value is used for both axes).

    Returns:
        tuple: The shape of the patch grid in the format
        (num_patches_height, num_patches_width).

    Raises:
        ValueError: If `grid_pinpoints` is not a list.
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size


def image_text_replacement(image_input, config, image_id) -> str:
    """
    Return the placeholder text that stands in for image number `image_id` in a prompt.

    The number of placeholder tokens has to match the number of image features
    the model produces for that image, so it depends on the model type.

    Args:
        image_input: Image-processor output for the whole batch. Only read for
            `llava_next`, where `image_input["image_sizes"][image_id]` holds the
            (height, width) of the image.
        config: Model config; `config.model_type` selects the strategy.
        image_id: Index of the image across all requests of the batch.

    Raises:
        RuntimeError: If `config.model_type` is not a supported multimodal model.
    """
    # NOTE(review): these placeholder strings must match the image special
    # tokens known to each model's tokenizer.
    if config.model_type == "idefics2":
        # TODO technically depends on image splitting which is not implemented.
        num_features = 320
        return (
            "<fake_token_around_image>"
            + "<image>" * num_features
            + "<fake_token_around_image>"
        )
    elif config.model_type == "llava_next":
        # `image_sizes` comes from the processor with return_tensors="pt", so
        # its entries may be 0-d tensors; cast them to plain ints for the
        # arithmetic below and for a readable log line.
        height, width = (int(x) for x in image_input["image_sizes"][image_id])
        num_features = get_number_of_features(height, width, config)
        from loguru import logger

        logger.info(
            f"Found {num_features} features in image of resolution {height}x{width}"
        )
        return "<image>" * num_features
    elif config.model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")


def get_unpadded_features(
    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
    """
    Count the high-resolution features that remain after unpadding.

    The tiles form a feature map of `npatches * num_patch_height` rows by
    `npatches * num_patch_width` columns. The image only fills part of that map
    along one axis. The padding along that axis is removed by dropping the same
    number of rows (or columns) from both sides. This is intended to match the
    `unpad_image` step of the LLaVA-NeXT model.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        npatches: Number of patches per side of a single tile.
        num_patch_height: Number of tiles stacked vertically.
        num_patch_width: Number of tiles placed horizontally.

    Returns:
        (unpadded_features, newline_features): the number of features left after
        unpadding, and the number of remaining rows (one newline feature per row).
    """
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = width / height
    current_aspect_ratio: float = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        # Relatively wider image: padding sits above and below it.
        new_height = (height * current_width) // width
        # The same amount is removed on each side, so when the total padding is
        # odd one extra row survives; using `new_height` directly undercounts.
        padding = (current_height - new_height) // 2
        current_height = current_height - 2 * padding
    else:
        # Relatively taller image: padding sits left and right of it.
        new_width = (width * current_height) // height
        padding = (current_width - new_width) // 2
        current_width = current_width - 2 * padding

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


def get_number_of_features(height: int, width: int, config) -> int:
    """
    Return the number of image features for a `height` x `width` llava_next image.

    The total is the base tile (covering the whole image), plus the unpadded
    high-resolution features, plus one newline feature per high-resolution row.

    Raises:
        ValueError: If `config.vision_config.image_size` is not a multiple of
            `config.vision_config.patch_size`.
    """
    image_grid_pinpoints = config.image_grid_pinpoints

    image_size = config.vision_config.image_size
    patch_size = config.vision_config.patch_size
    if image_size % patch_size != 0:
        raise ValueError(
            f"vision_config.image_size ({image_size}) must be a multiple of "
            f"vision_config.patch_size ({patch_size})"
        )
    npatches = image_size // patch_size

    # `image_size` is passed as the "patch size" on purpose: the result is the
    # number of vision-tower tiles along each axis of the selected resolution.
    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        [height, width],
        image_grid_pinpoints,
        image_size,
    )

    unpadded_features, newline_features = get_unpadded_features(
        height, width, npatches, num_patch_height, num_patch_width
    )
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features


class VlmCausalLMBatch(FlashCausalLMBatch):
    """Flash causal-LM batch that also carries the processed image inputs."""

    # Set by `from_pb_processor`. They are reset to None by `concatenate` and
    # `filter`, and cleared by `VlmCausalLM.forward` after a non-CUDA-graph call.
    pixel_values: Optional[torch.Tensor]
    pixel_attention_mask: Optional[torch.Tensor]
    image_sizes: Optional[torch.Tensor]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        """Concatenate `batches`; the result carries no image inputs."""
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        """Keep only `request_ids`; the result carries no image inputs."""
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
    ):
        """
        Preprocess the images of `requests` and tokenize their prompts.

        Every image chunk is replaced in the prompt by the placeholder text from
        `image_text_replacement`.

        Returns:
            (input_ids, image_inputs): the tokenizer's `input_ids` per request,
            and the image-processor output for all images of the batch (None
            when the batch has no images).

        Raises:
            RuntimeError: On a chunk that is neither text nor image.
        """
        # `requests` is traversed twice below; materialize it so a one-shot
        # iterable is not exhausted by the first pass.
        requests = list(requests)

        # Process images first. We need all of them so that the processor
        # can make the image splits the same size. And we need the final
        # sizes to insert correct number of image tokens.
        images = []
        for r in requests:
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    continue
                if chunk_type != "image":
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
                image = Image.open(BytesIO(chunk.image.data))
                # llava_next gets a flat list of images; other model types get
                # each image wrapped in its own list.
                if config.model_type == "llava_next":
                    images.append(image)
                else:
                    images.append([image])

        if images:
            image_inputs = processor.image_processor(images, return_tensors="pt")
        else:
            image_inputs = None

        batch_inputs = []
        max_truncation = 0
        image_id = 0
        for r in requests:
            parts = []
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    parts.append(chunk.text)
                elif chunk_type == "image":
                    parts.append(image_text_replacement(image_inputs, config, image_id))
                    image_id += 1
            batch_inputs.append("".join(parts))
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=config.model_type != "paligemma",
        )["input_ids"]
        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        """Build a batch from `pb`, moving any image tensors to `device`."""
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

        if image_inputs is None:
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
            return batch

        batch.pixel_values = image_inputs["pixel_values"].to(device=device)
        # These two keys are only present for some processors.
        if "pixel_attention_mask" in image_inputs:
            batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                device=device
            )
        else:
            batch.pixel_attention_mask = None
        if "image_sizes" in image_inputs:
            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
        else:
            batch.image_sizes = None
        return batch


class VlmCausalLM(BaseFlashMistral):
    """Vision-language causal LM built on `BaseFlashMistral`."""

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return VlmCausalLMBatch

    def forward(
        self,
        batch: VlmCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run the model on `batch` and return `(logits, speculative_logits)`.

        The batch is run eagerly for prefill, or when no CUDA graph is large
        enough for it; otherwise a captured CUDA graph is replayed.
        `adapter_data` is accepted for interface compatibility and is not used.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        if batch.speculative_ids is not None:
            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)

            # Each sequence becomes its current token followed by its
            # speculative tokens, laid out contiguously.
            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        # Try to find an associated cuda graph: the smallest captured batch
        # size that can hold this batch (static inputs are padded up to it).
        bs = input_ids.shape[0]
        padded_bs = min((k for k in self.cuda_graphs if k >= bs), default=None)
        cuda_graph = self.cuda_graphs[padded_bs] if padded_bs is not None else None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
                pixel_attention_mask=batch.pixel_attention_mask,
                image_sizes=batch.image_sizes,
            )
            # Drop the one-shot inputs so later forward calls on this batch
            # do not pass them to the model again.
            batch.prefill_cache_indices = None
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        # Padding slots are set to -1 and padding lengths to 0.
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits