Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 14:52:20 +00:00
* Fix runtime error when Qwen2-VL was prompted with multiple images
Fix a runtime error that occurred when the Qwen2-VL model was prompted with more
than one image. The runtime error was:
File "text-generation-inference/server/text_generation_server/models/custom_modeling/qwen2_vl.py", line 459, in get_position_ids
text_pos_ids = torch.arange(text_length, device=d)
RuntimeError: upper bound and larger bound inconsistent with step sign
The error was caused by the text_length variable becoming negative when a
prompt with multiple images drove multiple iterations of the main loop in
get_position_ids.
The root cause is a simple logic mistake: next_image_pos is computed as an
offset relative to current_pos, but it was then used as if it were an absolute
position from zero.
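As a rough illustration only (this is a simplified sketch, not the actual
get_position_ids code; the helper name and arguments are made up, and each
image is reduced to a single image token), a match found in input_ids[current_pos:]
is already relative to current_pos, so the text span length is the match offset
itself:

    import torch

    def text_span_lengths(input_ids: torch.Tensor, image_token_id: int):
        # Length of each text run between image tokens.
        lengths = []
        current_pos = 0
        while current_pos < input_ids.numel():
            remaining = input_ids[current_pos:]
            hits = (remaining == image_token_id).nonzero()
            if hits.numel() == 0:
                lengths.append(remaining.numel())
                break
            next_image_pos = hits[0].item()  # offset relative to current_pos
            # Buggy version: lengths.append(next_image_pos - current_pos),
            # which goes negative once current_pos has advanced past the first
            # image and makes torch.arange(text_length) fail.
            lengths.append(next_image_pos)
            current_pos += next_image_pos + 1  # step past the image token
        return lengths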
* Fix runtime error when Qwen2-VL was prompted with multiple images
Fix a runtime error that occurred when the Qwen2-VL model was prompted with more
than one image. The runtime error was:
File "text-generation-inference/server/text_generation_server/models/custom_modeling/qwen2_vl.py", line 534, in forward
inputs_embeds[input_ids == self.image_token_id] = image_embeds
RuntimeError: shape mismatch: value tensor of shape [512, 3584] cannot be broadcast to indexing result of shape [1024, 3584]
(The shape values in the error message vary with the input image resolutions.)
The error was caused by adding the wrong number of <|image_pad|> tokens
to the tokenized input in the image_text_replacement function.
The root cause is another simple logic mistake: the number of image pad tokens
was taken from the length of the first dimension of the pixel_values shape.
However, that tensor contains the patches of all of the images, so the code
inserted the pad-token count for the whole input at every image location,
leaving extra image pad tokens in the tokenized input.
The fix is to derive the number of required tokens from the image_grid_thw
tensor, which holds the grid_t, grid_h, and grid_w values for each image.
grid_t * grid_h * grid_w gives the total number of patches for that image [1],
and the number of required image pad tokens is number_of_patches // 4.
[1] 31f9a289a6/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py (L311)
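A minimal sketch of the difference (the helper name is invented for
illustration; image_grid_thw follows the Qwen2-VL processor output, one
(t, h, w) row per image, and the // 4 corresponds to the 2x2 patch merge of
the vision tower):

    def num_image_pad_tokens(image_grid_thw, image_id):
        # The buggy version used pixel_values.shape[0] // 4, i.e. the patch
        # count of *all* images, so every image received the total pad count.
        grid_t, grid_h, grid_w = image_grid_thw[image_id]
        return int(grid_t * grid_h * grid_w) // 4  # patches of this image only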
---------
Co-authored-by: Janne Alatalo <janne.alatalo@jamk.fi>
487 lines
19 KiB
Python
import torch
from PIL import Image
from io import BytesIO

from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict

from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
    FlashCausalLMBatch,
    FlashCausalLM,
)
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.metadata_kernels import block_tables_to_ragged

tracer = trace.get_tracer(__name__)

IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
IDEFICS2_IMAGE_TOKEN = "<image>"


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size


def image_text_replacement(processor, image_input, config, image_id: int) -> str:
    if config.model_type == "idefics2":
        image_seq_len = 64
        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
        if processor.image_processor.do_image_splitting:
            image_str *= 5
        return image_str
    elif config.model_type == "llava_next":
        height, width = image_input["image_sizes"][image_id]
        num_features = get_number_of_features(height, width, config)
        from loguru import logger

        log_master(
            logger.info,
            f"Found {num_features} features in image of resolution {height}x{width}",
        )
        return "<image>" * num_features

    elif config.model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens
    elif config.model_type == "qwen2_vl":
        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
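        # Each image contributes grid_t * grid_h * grid_w visual patches, and
        # one <|image_pad|> token is needed per 4 patches (the 2x2 spatial
        # merge of the vision encoder), hence the // 4 below.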
        num_pads = grid_t * grid_h * grid_w // 4
        padding = "<|image_pad|>" * num_pads
        return f"<|vision_start|>{padding}<|vision_end|>"
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")


def image_text_replacement_fixup(config, text: str) -> str:
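    # idefics2 wraps every image in fake tokens; two adjacent images would
    # otherwise produce a doubled fake token, so collapse it back to one.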
    if config.model_type == "idefics2":
        return text.replace(
            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
        )
    return text


def get_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
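    # Mirrors LLaVA-NeXT's unpadding: drop the patch rows/columns that only
    # cover aspect-ratio padding, and count one newline feature per remaining row.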
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = original_width / original_height
    current_aspect_ratio: float = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height = current_height - (2 * padding)
    else:
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width = current_width - (2 * padding)

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


def get_number_of_features(height: int, width: int, config) -> int:
    # From config
    # Hardcoded for CLIP for now
    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
    image_grid_pinpoints = config.image_grid_pinpoints
    image_size = config.vision_config.image_size
    patch_size = config.vision_config.patch_size

    assert image_size % patch_size == 0

    npatches = image_size // patch_size

    # Dimensions are intentionally swapped to be bug-compatible with
    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        [height, width],
        image_grid_pinpoints,
        image_size,
    )
    unpadded_features, newline_features = get_unpadded_features(
        height, width, npatches, num_patch_height, num_patch_width
    )
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features


class VlmCausalLMBatch(FlashCausalLMBatch):
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]
    image_grid_thw: Optional[torch.Tensor]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
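        # The image tensors are consumed by the prefill forward pass (which
        # clears them, see forward below), so concatenated batches no longer
        # need to carry them.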
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        return batch

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
    ):
        # Process images first. We need all of them so that the processor
        # can make the image splits the same size. And we need the final
        # sizes to insert correct number of image tokens.
        images = []
        for r in requests:
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    pass
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
                    # default warmup image is 20x20
                    if config.model_type == "qwen2_vl":
                        if image.width <= 20:
                            w = image.width * 2
                            h = image.height * 2
                            image = image.resize((w, h))

                    if config.model_type == "llava_next":
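                        # llava_next's image processor takes a flat list of
                        # images, while the other processors expect one nested
                        # list per image (handled by the else branch).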
                        images.append(image)
                    else:
                        images.append([image])
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

        if images:
            image_inputs = processor.image_processor(images, return_tensors="pt")
        else:
            image_inputs = None

        batch_inputs = []
        max_truncation = 0
        image_id = 0
        for r in requests:
            full_text = ""
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    full_text += chunk.text
                elif chunk_type == "image":
                    full_text += image_text_replacement(
                        processor, image_inputs, config, image_id
                    )
                    image_id += 1

            full_text = image_text_replacement_fixup(config, full_text)

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=not config.model_type == "paligemma",
        )["input_ids"]

        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            if "pixel_attention_mask" in image_inputs:
                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                    device=device
                )
            else:
                batch.pixel_attention_mask = None
            if "image_sizes" in image_inputs:
                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
            else:
                batch.image_sizes = None
            if "image_grid_thw" in image_inputs:
                batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
            else:
                batch.image_grid_thw = None
        else:
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
            batch.image_grid_thw = None
        return batch


class VlmCausalLM(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        *,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        batch_class=VlmCausalLMBatch,
        revision,
        trust_remote_code: bool,
        **kwargs,
    ):
        if PREFIX_CACHING:
            raise NotImplementedError("Vlm do not work with prefix caching yet")
        if processor_kwargs is None:
            processor_kwargs = {}
        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )
        self.batch_class = batch_class
        super().__init__(
            model_id=model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            # FIXME: VLM do not work with context chunking yet
            support_chunking=False,
            **kwargs,
        )

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return self.batch_class

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)

    def forward(
        self,
        batch: VlmCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
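            # Speculative decoding: append each sequence's speculative ids to
            # its input so one forward pass scores speculative_length + 1
            # tokens per request; positions, slots, lengths and block tables
            # are expanded to match.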
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            cache_lengths_tensor = (
                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

            # Add Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            cache_lengths_tensor = batch.cache_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

        if self.model.config.model_type == "qwen2_vl":
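            # Qwen2-VL position ids depend on the image grids (multimodal
            # RoPE), so compute them once during prefill, while the batch
            # still carries the default 1D position ids.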
            if position_ids.dim() == 1 and batch.prefilling:
                position_ids = self.model.get_position_ids(
                    input_ids, batch.image_grid_thw
                )
                batch.position_ids = position_ids

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        # Try to find an associated cuda graph
        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None
        if cu_seqlen_prefill is not None or cuda_graph is None:
            if ATTENTION == "flashinfer":
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
                    cache_lengths=batch.cache_lengths,
                    input_lengths_tensor=batch.input_lengths_tensor,
                    cache_lengths_tensor=batch.cache_lengths_tensor,
                    max_current_length=batch.max_current_length,
                )
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths_tensor=input_lengths,
                cache_lengths_tensor=cache_lengths_tensor,
            ):
                seqlen = Seqlen(
                    input_lengths=input_lengths,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=batch.max_input_length,
                    max_k=batch.max_current_length,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=cu_seqlen_prefill,
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
                    pixel_values=batch.pixel_values,
                    pixel_attention_mask=batch.pixel_attention_mask,
                    image_sizes=batch.image_sizes,
                    image_grid_thw=batch.image_grid_thw,
                )
                if batch.prefill_cache_indices is not None:
                    batch.prefill_cache_indices = None
                if batch.pixel_values is not None:
                    batch.pixel_values = None
                if batch.pixel_attention_mask is not None:
                    batch.pixel_attention_mask = None
                if batch.image_sizes is not None:
                    batch.image_sizes = None
                if batch.image_grid_thw is not None:
                    batch.image_grid_thw = None
                return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=batch.input_lengths,
                cache_lengths=batch.cache_lengths,
                input_lengths_tensor=batch.input_lengths_tensor,
                cache_lengths_tensor=batch.cache_lengths_tensor,
                max_current_length=batch.max_current_length,
            )
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
        else:
            cuda_graph["block_tables"][
                : block_tables.shape[0], : block_tables.shape[1]
            ] = block_tables

        # XXX: This is working only because block 0 is reserved for the healthcheck
        # so it doesn't matter if we override it with bogus values.
        cuda_graph["slots"].fill_(0)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
        cuda_graph["cache_lengths"].zero_()
        cuda_graph["cache_lengths"][
            : cache_lengths_tensor.shape[0]
        ] = cache_lengths_tensor

        with self._forward_context(
            block_tables=cuda_graph["block_tables"],
            cu_seqlen_prefill=None,
            input_lengths_tensor=cuda_graph["input_lengths"],
            cache_lengths_tensor=cuda_graph["cache_lengths"],
            state=cuda_graph["state"],
        ):
            # Replay the graph
            cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits