mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-21 09:42:09 +00:00
* add logic * working * add encoder cache free * fixes * fix idefics * update pixel_values * add improvements * add improvements * improve * nit * fix inputs_embeds * nit * optimizations * add prometheus port * rename vars * rename vars * nit * disable chunking for qwen * review comments * remove port * improve headdim * remove kwargs and redundant args * fix qwen2_5 * fix config image_token_id error * fix test * update paligemma * fix paligemma text * minor fix * fix qwen test * fix qwen test
1134 lines
41 KiB
Python
1134 lines
41 KiB
Python
from dataclasses import dataclass
|
|
import torch
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
|
|
from opentelemetry import trace
|
|
from typing import Iterable, Optional, Tuple, List, Type, Dict
|
|
|
|
from transformers import PreTrainedTokenizerBase
|
|
from transformers.image_processing_utils import select_best_resolution
|
|
from text_generation_server.pb import generate_pb2
|
|
from text_generation_server.models.flash_causal_lm import (
|
|
FlashCausalLMBatch,
|
|
FlashCausalLM,
|
|
)
|
|
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION, MEM_POOL
|
|
from loguru import logger
|
|
from text_generation_server.utils.log import log_master
|
|
from transformers import AutoProcessor
|
|
from text_generation_server.layers.attention import Seqlen
|
|
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
|
|
|
|
tracer = trace.get_tracer(__name__)
|
|
|
|
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
|
|
IDEFICS2_IMAGE_TOKEN = "<image>"
|
|
|
|
IDEFICS3_IMAGE_TOKEN = "<image>"
|
|
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
|
|
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
|
|
|
|
|
|
def prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk):
|
|
"""
|
|
Create a structured string representation of image tokens
|
|
|
|
Args:
|
|
num_patches: Number of patches in the image
|
|
|
|
Returns:
|
|
String with appropriate image tokens
|
|
"""
|
|
img_string = "<|image_start|>"
|
|
ratio_h, ratio_w = aspect_ratio
|
|
if ratio_h * ratio_w > 1:
|
|
for yy in range(ratio_h):
|
|
for xx in range(ratio_w):
|
|
img_string += "<|patch|>" * num_patches_per_chunk
|
|
if xx < ratio_w - 1:
|
|
img_string += "<|tile_x_separator|>"
|
|
|
|
img_string += "<|tile_y_separator|>"
|
|
img_string += "<|image|>"
|
|
img_string += "<|patch|>" * num_patches_per_chunk
|
|
img_string += "<|image_end|>"
|
|
|
|
return img_string
|
|
|
|
|
|
# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
|
|
def _prompt_split_image(
|
|
*,
|
|
image_seq_len: int,
|
|
image_rows: int,
|
|
image_cols: int,
|
|
fake_token_around_image: str,
|
|
image_token: str,
|
|
global_img_token: str,
|
|
):
|
|
"""Prompt with expanded image tokens for when the image is split into patches."""
|
|
text_split_images = ""
|
|
for n_h in range(image_rows):
|
|
for n_w in range(image_cols):
|
|
text_split_images += (
|
|
f"{fake_token_around_image}"
|
|
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
|
|
+ f"{image_token}" * image_seq_len
|
|
)
|
|
text_split_images += "\n"
|
|
|
|
text_split_images += (
|
|
f"\n{fake_token_around_image}"
|
|
+ f"{global_img_token}"
|
|
+ f"{image_token}" * image_seq_len
|
|
+ f"{fake_token_around_image}"
|
|
)
|
|
return text_split_images
|
|
|
|
|
|
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
|
"""
|
|
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
|
|
|
Args:
|
|
image_size (`tuple`):
|
|
The size of the input image in the format (height, width).
|
|
grid_pinpoints (`List`):
|
|
A list containing possible resolutions. Each item in the list should be a tuple or list
|
|
of the form `(height, width)`.
|
|
patch_size (`int`):
|
|
The size of each image patch.
|
|
|
|
Returns:
|
|
tuple: The shape of the image patch grid in the format (width, height).
|
|
"""
|
|
if not isinstance(grid_pinpoints, list):
|
|
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
|
|
|
height, width = select_best_resolution(image_size, grid_pinpoints)
|
|
return height // patch_size, width // patch_size
|
|
|
|
|
|
def image_text_replacement(processor, image_input, config) -> str:
|
|
if config.model_type == "idefics2":
|
|
image_seq_len = 64
|
|
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
|
|
if processor.image_processor.do_image_splitting:
|
|
image_str *= 5
|
|
return image_str, IDEFICS2_FAKE_TOKEN
|
|
if config.model_type == "idefics3":
|
|
# TODO: implement this in a more general way
|
|
n_rows = image_input["rows"][0][0]
|
|
n_cols = image_input["cols"][0][0]
|
|
image_seq_len = int(
|
|
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
|
|
/ (config.scale_factor**2)
|
|
)
|
|
image_str = _prompt_split_image(
|
|
image_seq_len=image_seq_len,
|
|
image_rows=n_rows,
|
|
image_cols=n_cols,
|
|
fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
|
|
image_token=IDEFICS3_IMAGE_TOKEN,
|
|
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
|
|
)
|
|
return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
|
|
elif config.model_type == "llava_next":
|
|
height, width = image_input["image_sizes"][0]
|
|
num_features = get_number_of_features(height, width, config)
|
|
|
|
log_master(
|
|
logger.info,
|
|
f"Found {num_features} features in image of resolution {height}x{width}",
|
|
)
|
|
return "<image>" * num_features, "<image>"
|
|
|
|
elif config.model_type == "paligemma":
|
|
return "<image>" * config.text_config.num_image_tokens, "<image>"
|
|
elif config.model_type == "qwen2_vl":
|
|
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
|
|
num_pads = grid_t * grid_h * grid_w // 4
|
|
padding = "<|image_pad|>" * num_pads
|
|
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
|
|
elif config.model_type == "qwen2_5_vl":
|
|
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
|
|
num_pads = grid_t * grid_h * grid_w // 4
|
|
padding = "<|image_pad|>" * num_pads
|
|
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
|
|
elif config.model_type == "gemma3":
|
|
# TODO: get correct number of features via reviewing the Gemma3 architecture
|
|
# and calculating the number of image tokens
|
|
num_pads = 256
|
|
padding = "<image_soft_token>" * num_pads
|
|
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n", "<start_of_image>"
|
|
elif config.model_type == "llama4":
|
|
patch_size = config.vision_config.patch_size
|
|
pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
|
|
downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
|
|
aspect_ratios = image_input["aspect_ratios"][0]
|
|
image_height, image_width = image_input["pixel_values"][0].shape[-2:]
|
|
|
|
num_patches_per_chunk = int(
|
|
(image_height // patch_size)
|
|
* (image_width // patch_size)
|
|
// downsample_ratio
|
|
)
|
|
tokens_for_this_image = prompt_split_image_llama4(
|
|
aspect_ratios, num_patches_per_chunk
|
|
)
|
|
|
|
return tokens_for_this_image, "<|image_start|>"
|
|
else:
|
|
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
|
|
|
|
|
def image_text_replacement_fixup(config, text: str) -> str:
|
|
if config.model_type == "idefics2":
|
|
return text.replace(
|
|
f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
|
|
)
|
|
return text
|
|
|
|
|
|
def preprocess_text(config, text: str) -> str:
|
|
if config.model_type == "paligemma":
|
|
return "<bos>" + text + "\n"
|
|
return text
|
|
|
|
|
|
def preprocess_image(config, img):
|
|
model_type = config.model_type
|
|
|
|
if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
|
|
img = img.resize((img.width * 2, img.height * 2))
|
|
if model_type == "paligemma":
|
|
img = img.convert("RGB")
|
|
|
|
if model_type not in {"llava_next", "gemma3", "llama4"}:
|
|
# TODO: check if this is needed
|
|
img = [img]
|
|
|
|
return img
|
|
|
|
|
|
def get_unpadded_features(
|
|
original_height: int,
|
|
original_width: int,
|
|
npatches: int,
|
|
num_patch_height: int,
|
|
num_patch_width: int,
|
|
) -> Tuple[int, int]:
|
|
current_height = npatches * num_patch_height
|
|
current_width = npatches * num_patch_width
|
|
|
|
aspect_ratio: float = original_width / original_height
|
|
current_aspect_ratio: float = current_width / current_height
|
|
|
|
if aspect_ratio > current_aspect_ratio:
|
|
new_height = (original_height * current_width) // original_width
|
|
padding = (current_height - new_height) // 2
|
|
current_height = current_height - (2 * padding)
|
|
else:
|
|
new_width = (original_width * current_height) // original_height
|
|
padding = (current_width - new_width) // 2
|
|
current_width = current_width - (2 * padding)
|
|
|
|
unpadded_features = current_height * current_width
|
|
newline_features = current_height
|
|
return (unpadded_features, newline_features)
|
|
|
|
|
|
def get_number_of_features(height: int, width: int, config) -> int:
|
|
# From config
|
|
# Hardcoded for CLIP for now
|
|
# image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
|
|
image_grid_pinpoints = config.image_grid_pinpoints
|
|
image_size = config.vision_config.image_size
|
|
patch_size = config.vision_config.patch_size
|
|
|
|
assert image_size % patch_size == 0
|
|
|
|
npatches = image_size // patch_size
|
|
|
|
# Dimensions are intentionally swapped to be bug-compatible with
|
|
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
|
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
|
[height, width],
|
|
image_grid_pinpoints,
|
|
image_size,
|
|
)
|
|
unpadded_features, newline_features = get_unpadded_features(
|
|
height, width, npatches, num_patch_height, num_patch_width
|
|
)
|
|
# The base patch covers the entire image
|
|
base_features = npatches**2
|
|
return unpadded_features + newline_features + base_features
|
|
|
|
|
|
def scatter_image_embeds(
|
|
embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
|
|
) -> torch.Tensor:
|
|
if is_embed is None:
|
|
return embeds
|
|
|
|
placeholders = embeds.new_full(
|
|
(is_embed.shape[0], embeds.shape[-1]),
|
|
fill_value=torch.nan,
|
|
)
|
|
placeholders[is_embed] = embeds
|
|
return placeholders
|
|
|
|
|
|
def gather_image_embeds(
|
|
embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
|
|
) -> Optional[torch.Tensor]:
|
|
if is_embed is None:
|
|
return embeds
|
|
sel = embeds[is_embed]
|
|
return sel if sel.numel() else None
|
|
|
|
|
|
@dataclass
|
|
class ImagePositions:
|
|
offset: int
|
|
length: int
|
|
id: int
|
|
num_placeholder_tokens: int
|
|
is_embed: Optional[torch.Tensor] = None
|
|
|
|
|
|
class VlmCausalLMBatch(FlashCausalLMBatch):
|
|
image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]
|
|
image_positions: Optional[List[List[ImagePositions]]]
|
|
encoder_cache: Optional[List[Dict[int, torch.Tensor]]]
|
|
pixel_values: Optional[List[torch.Tensor]]
|
|
pixel_attention_mask: Optional[List[torch.Tensor]]
|
|
image_sizes: Optional[List[Tuple[int, int]]]
|
|
image_grid_thw: Optional[torch.Tensor]
|
|
cache_entries_to_free: List[Tuple[int, int]]
|
|
has_image_inputs: bool = False
|
|
inputs_embeds: Optional[torch.Tensor] = None
|
|
|
|
@classmethod
|
|
@tracer.start_as_current_span("concatenate")
|
|
def concatenate(cls, batches):
|
|
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
|
|
|
|
batch.image_inputs = []
|
|
batch.image_positions = []
|
|
batch.encoder_cache = []
|
|
for b in batches:
|
|
if b.image_inputs is not None:
|
|
batch.image_inputs.extend(b.image_inputs)
|
|
else:
|
|
batch.image_inputs.append(None)
|
|
if b.image_positions is not None:
|
|
batch.image_positions.extend(b.image_positions)
|
|
else:
|
|
batch.image_positions.append(None)
|
|
if b.encoder_cache is not None:
|
|
batch.encoder_cache.extend(b.encoder_cache)
|
|
else:
|
|
batch.encoder_cache.append(None)
|
|
|
|
batch.pixel_values = None
|
|
batch.pixel_attention_mask = None
|
|
batch.image_sizes = None
|
|
batch.image_grid_thw = None
|
|
batch.inputs_embeds = None
|
|
|
|
# To be filled in prepare_for_prefill
|
|
batch.has_image_inputs = False
|
|
batch.cache_entries_to_free = []
|
|
|
|
return batch
|
|
|
|
@tracer.start_as_current_span("filter")
|
|
def filter(self, request_ids: List[int]):
|
|
if len(request_ids) == 0:
|
|
raise ValueError("Batch must have at least one request")
|
|
|
|
image_inputs = []
|
|
image_positions = []
|
|
encoder_cache = []
|
|
|
|
for request_id in request_ids:
|
|
idx = self.requests_idx_mapping[request_id]
|
|
image_inputs.append(self.image_inputs[idx])
|
|
image_positions.append(self.image_positions[idx])
|
|
encoder_cache.append(self.encoder_cache[idx])
|
|
|
|
batch = super().filter(request_ids)
|
|
batch.pixel_values = None
|
|
batch.pixel_attention_mask = None
|
|
batch.image_sizes = None
|
|
batch.image_grid_thw = None
|
|
batch.inputs_embeds = None
|
|
|
|
batch.image_inputs = image_inputs
|
|
batch.image_positions = image_positions
|
|
batch.encoder_cache = encoder_cache
|
|
|
|
# To be filled in prepare_for_prefill
|
|
batch.has_image_inputs = False
|
|
batch.cache_entries_to_free = []
|
|
return batch
|
|
|
|
@classmethod
|
|
def batch_tokenized_inputs(
|
|
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
|
|
):
|
|
kwargs = {}
|
|
if (
|
|
hasattr(processor, "image_processor_class")
|
|
and processor.image_processor_class == "Idefics3ImageProcessor"
|
|
):
|
|
kwargs["return_row_col_info"] = True
|
|
|
|
max_length = 0
|
|
vocab = tokenizer.get_vocab()
|
|
|
|
if not hasattr(config, "image_token_index"):
|
|
config.image_token_index = config.image_token_id
|
|
|
|
batch_tokenized_inputs: List[List[int]] = []
|
|
batch_image_inputs: List[Optional[List[dict]]] = []
|
|
batch_image_positions: List[Optional[List[ImagePositions]]] = []
|
|
|
|
for r in requests:
|
|
text_parts = []
|
|
image_inputs = []
|
|
image_texts = []
|
|
|
|
image_id = 0
|
|
|
|
for chunk in r.input_chunks.chunks:
|
|
chunk_type = chunk.WhichOneof("chunk")
|
|
if chunk_type == "text":
|
|
text = preprocess_text(config, chunk.text)
|
|
text_parts.append(text)
|
|
elif chunk_type == "image":
|
|
img = Image.open(BytesIO(chunk.image.data))
|
|
img = preprocess_image(config, img)
|
|
|
|
image_input = processor.image_processor(
|
|
[img], return_tensors="pt", **kwargs
|
|
)
|
|
image_inputs.append(image_input)
|
|
|
|
img_text, img_start_token_str = image_text_replacement(
|
|
processor, image_input, config
|
|
)
|
|
text_parts.append(img_text)
|
|
|
|
image_texts.append([image_id, img_start_token_str, img_text])
|
|
image_id += 1
|
|
else:
|
|
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
|
|
|
full_text = image_text_replacement_fixup(config, "".join(text_parts))
|
|
input_ids = tokenizer(
|
|
full_text,
|
|
truncation=True,
|
|
max_length=r.truncate,
|
|
add_special_tokens=(
|
|
r.add_special_tokens if config.model_type != "paligemma" else False
|
|
),
|
|
)["input_ids"]
|
|
max_length = max(max_length, len(input_ids))
|
|
|
|
if len(image_inputs) > 0:
|
|
img_start_token = vocab[image_texts[0][1]]
|
|
image_positions = cls.get_image_positions(
|
|
input_ids, image_texts, img_start_token, config, tokenizer
|
|
)
|
|
else:
|
|
image_inputs = None
|
|
image_positions = None
|
|
|
|
batch_tokenized_inputs.append(input_ids)
|
|
batch_image_inputs.append(image_inputs)
|
|
batch_image_positions.append(image_positions)
|
|
|
|
return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
|
|
|
|
@classmethod
|
|
def get_image_positions(
|
|
cls,
|
|
input_ids: List[int],
|
|
image_texts: List[Tuple[int, str, str]],
|
|
img_start_token: int,
|
|
config,
|
|
tokenizer: PreTrainedTokenizerBase,
|
|
) -> List[ImagePositions]:
|
|
image_positions = []
|
|
num_images = len(image_texts)
|
|
|
|
input_ids_t = torch.as_tensor(input_ids)
|
|
img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
|
|
num_tokens = input_ids_t.numel()
|
|
|
|
last_pos = 0
|
|
for i in range(num_images):
|
|
image_id, img_start_token_str, img_text = image_texts[i]
|
|
img_text = image_text_replacement_fixup(config, img_text)
|
|
|
|
if config.model_type == "gemma3":
|
|
img_text = img_text.replace("\n\n", "")
|
|
|
|
tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[
|
|
"input_ids"
|
|
][0]
|
|
length = tokens.numel()
|
|
|
|
assert (
|
|
length <= num_tokens
|
|
), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens"
|
|
|
|
pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
|
|
index = img_start_token_pos[pos]
|
|
assert torch.equal(
|
|
input_ids_t[index : index + length], tokens
|
|
), "Image tokens not found in input_ids"
|
|
|
|
is_embed = tokens == config.image_token_index
|
|
num_placeholder_tokens = int(is_embed.sum())
|
|
if num_placeholder_tokens == length:
|
|
is_embed = None
|
|
|
|
pos = ImagePositions(
|
|
offset=index,
|
|
length=length,
|
|
id=image_id,
|
|
num_placeholder_tokens=num_placeholder_tokens,
|
|
is_embed=is_embed,
|
|
)
|
|
|
|
image_positions.append(pos)
|
|
last_pos = index + length
|
|
|
|
if (
|
|
config.model_type == "idefics2"
|
|
and i + 1 != num_images
|
|
and input_ids[last_pos] == config.image_token_index
|
|
):
|
|
fake_token = last_pos - 1
|
|
fake_token_index = torch.searchsorted(
|
|
img_start_token_pos, fake_token, right=False
|
|
)
|
|
img_start_token_pos[fake_token_index] = last_pos
|
|
image_texts[i + 1][2] = image_texts[i + 1][2][
|
|
len(img_start_token_str) :
|
|
]
|
|
|
|
return image_positions
|
|
|
|
@classmethod
|
|
def from_pb_processor(
|
|
cls,
|
|
pb: generate_pb2.Batch,
|
|
tokenizer: PreTrainedTokenizerBase,
|
|
processor,
|
|
config,
|
|
dtype: torch.dtype,
|
|
device: torch.device,
|
|
) -> "VlmCausalLMBatch":
|
|
batch_tokenized_inputs, image_inputs, image_positions = (
|
|
cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
|
|
)
|
|
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
|
batch.image_inputs = image_inputs
|
|
batch.image_positions = image_positions
|
|
batch.encoder_cache = [{} for _ in range(len(pb.requests))]
|
|
if len(image_inputs):
|
|
batch.pixel_values = None
|
|
batch.pixel_attention_mask = None
|
|
batch.image_sizes = None
|
|
batch.image_grid_thw = None
|
|
return batch
|
|
|
|
def prepare_for_prefill(self):
|
|
super().prepare_for_prefill()
|
|
|
|
self.has_image_inputs = False
|
|
self.cache_entries_to_free = []
|
|
|
|
self.pixel_values = []
|
|
|
|
assert (
|
|
len(self.cache_lengths)
|
|
== len(self.input_lengths)
|
|
== len(self.prefilling_mask)
|
|
), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
|
|
|
|
for i, (
|
|
cache_length,
|
|
input_length,
|
|
request_prefilling,
|
|
) in enumerate(
|
|
zip(
|
|
self.cache_lengths,
|
|
self.input_lengths,
|
|
self.prefilling_mask,
|
|
)
|
|
):
|
|
if not request_prefilling or self.image_positions[i] is None:
|
|
continue
|
|
|
|
for image_position in self.image_positions[i]:
|
|
if image_position is None:
|
|
continue
|
|
start_pos = image_position.offset
|
|
length = image_position.length
|
|
|
|
if start_pos >= cache_length + input_length:
|
|
# No encoder input required at this step
|
|
break
|
|
if start_pos + length <= cache_length:
|
|
# The encode input is already processed
|
|
continue
|
|
|
|
self.has_image_inputs = True
|
|
|
|
if image_position.id not in self.encoder_cache[i]:
|
|
image_inputs = self.image_inputs[i][image_position.id]
|
|
self.pixel_values.append((i, image_position.id, image_inputs))
|
|
|
|
# Remove the image from the image_inputs
|
|
self.image_inputs[i][image_position.id] = None
|
|
|
|
if not self.has_image_inputs:
|
|
self.pixel_values = None
|
|
self.pixel_attention_mask = None
|
|
self.image_sizes = None
|
|
self.image_grid_thw = None
|
|
else:
|
|
image_grid_thw_list = [
|
|
x[2]["image_grid_thw"]
|
|
for x in self.pixel_values
|
|
if "image_grid_thw" in x[2]
|
|
]
|
|
if image_grid_thw_list:
|
|
self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to(
|
|
self.input_ids.device
|
|
)
|
|
else:
|
|
self.image_grid_thw = None
|
|
|
|
def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
|
|
self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
|
|
encoder_outputs, img_pos.is_embed
|
|
)
|
|
|
|
def gather_vision_embeds(self):
|
|
device = self.input_ids.device
|
|
chunks = []
|
|
for (
|
|
i,
|
|
cache_length,
|
|
input_length,
|
|
request_prefilling,
|
|
) in zip(
|
|
range(len(self.requests)),
|
|
self.cache_lengths,
|
|
self.input_lengths,
|
|
self.prefilling_mask,
|
|
):
|
|
if not request_prefilling or self.image_positions[i] is None:
|
|
continue
|
|
|
|
for image_position in self.image_positions[i]:
|
|
if image_position is None:
|
|
continue
|
|
start_pos = image_position.offset
|
|
length = image_position.length
|
|
|
|
if start_pos >= cache_length + input_length:
|
|
# No encoder input required at this step
|
|
break
|
|
if start_pos + length <= cache_length:
|
|
# The encode input is already processed
|
|
continue
|
|
|
|
start_idx = max(cache_length - start_pos, 0)
|
|
end_idx = min(cache_length - start_pos + input_length, length)
|
|
|
|
assert (
|
|
image_position.id in self.encoder_cache[i]
|
|
), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
|
|
encoder_output = self.encoder_cache[i][image_position.id]
|
|
|
|
is_embed = image_position.is_embed
|
|
if is_embed is not None:
|
|
is_embed = is_embed[start_idx:end_idx]
|
|
|
|
from loguru import logger
|
|
|
|
logger.info(
|
|
f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}"
|
|
)
|
|
|
|
embeds = gather_image_embeds(
|
|
encoder_output[start_idx:end_idx],
|
|
is_embed=is_embed,
|
|
)
|
|
if embeds is not None:
|
|
chunks.append(embeds)
|
|
|
|
if end_idx == length:
|
|
self.cache_entries_to_free.append((i, image_position.id))
|
|
self.image_positions[i][image_position.id] = None
|
|
|
|
if len(chunks) == 0:
|
|
return None
|
|
return torch.cat(chunks, dim=0).to(device)
|
|
|
|
def free_encoder_cache(self):
|
|
for i, image_id in self.cache_entries_to_free:
|
|
self.encoder_cache[i].pop(image_id, None)
|
|
|
|
self.cache_entries_to_free = []
|
|
|
|
# release any freed GPU memory immediately?
|
|
|
|
|
|
class VlmCausalLM(FlashCausalLM):
|
|
def __init__(
|
|
self,
|
|
model_id: str,
|
|
*,
|
|
processor_class=AutoProcessor,
|
|
processor_kwargs=None,
|
|
batch_class=VlmCausalLMBatch,
|
|
revision,
|
|
trust_remote_code: bool,
|
|
support_chunking: bool = True,
|
|
**kwargs,
|
|
):
|
|
if PREFIX_CACHING:
|
|
raise NotImplementedError("Vlm do not work with prefix caching yet")
|
|
if processor_kwargs is None:
|
|
processor_kwargs = {}
|
|
self.processor = processor_class.from_pretrained(
|
|
model_id,
|
|
revision=revision,
|
|
trust_remote_code=trust_remote_code,
|
|
**processor_kwargs,
|
|
)
|
|
self.batch_class = batch_class
|
|
super().__init__(
|
|
model_id=model_id,
|
|
revision=revision,
|
|
trust_remote_code=trust_remote_code,
|
|
support_chunking=support_chunking,
|
|
**kwargs,
|
|
)
|
|
|
|
@property
|
|
def batch_type(self) -> Type[VlmCausalLMBatch]:
|
|
return self.batch_class
|
|
|
|
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
|
max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
|
|
input_lengths = [max_s] * bs
|
|
cache_lengths = [0] * bs
|
|
config = getattr(self.model.config, "text_config", self.model.config)
|
|
if max_bs is None:
|
|
inputs_embeds = torch.zeros(
|
|
(bs, config.hidden_size),
|
|
device=self.device,
|
|
dtype=self.dtype,
|
|
)
|
|
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
|
config = getattr(self.model, "config", None)
|
|
rope_scaling = getattr(config, "rope_scaling", None) if config else None
|
|
if ( # mrope have position_ids per section, if so repeat n times
|
|
isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope"
|
|
):
|
|
n_sections = len(self.model.config.rope_scaling["mrope_section"])
|
|
position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
|
|
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
|
input_lengths_tensor = (
|
|
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
|
)
|
|
cache_lengths_tensor = torch.zeros(
|
|
bs, dtype=torch.int32, device=self.device
|
|
)
|
|
block_tables = torch.arange(
|
|
max_bt, dtype=torch.int32, device=self.device
|
|
).repeat(bs)
|
|
block_tables = block_tables.reshape((bs, max_bt))
|
|
if ATTENTION == "flashinfer":
|
|
block_tables = block_tables_to_ragged(
|
|
block_tables=block_tables,
|
|
input_lengths=input_lengths,
|
|
cache_lengths=cache_lengths,
|
|
input_lengths_tensor=input_lengths_tensor,
|
|
cache_lengths_tensor=cache_lengths_tensor,
|
|
max_current_length=max_s,
|
|
)
|
|
else:
|
|
if bs > max_bs:
|
|
raise RuntimeError(
|
|
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
|
|
)
|
|
inputs_embeds = self.cuda_graphs[max_bs]["inputs_embeds"][:bs]
|
|
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
|
|
if ATTENTION == "flashinfer":
|
|
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
|
|
else:
|
|
block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
|
|
slots = self.cuda_graphs[max_bs]["slots"][:bs]
|
|
input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
|
|
cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
|
|
|
|
if ATTENTION == "flashinfer":
|
|
from text_generation_server.layers.attention.flashinfer import (
|
|
create_decode_state_cuda_graphs,
|
|
)
|
|
|
|
block_tables_ptr = torch.zeros(
|
|
bs + 1, dtype=torch.int32, device=self.device
|
|
)
|
|
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
|
|
state = create_decode_state_cuda_graphs(
|
|
device=inputs_embeds.device,
|
|
block_tables=block_tables,
|
|
block_tables_ptr=block_tables_ptr,
|
|
last_page_len=last_page_len,
|
|
num_heads=self.num_heads,
|
|
num_kv_heads=self.num_kv_heads,
|
|
)
|
|
else:
|
|
state = None
|
|
|
|
graph = torch.cuda.CUDAGraph()
|
|
self.cuda_graphs[bs] = {
|
|
"inputs_embeds": inputs_embeds,
|
|
"position_ids": position_ids,
|
|
"kv_cache": self.kv_cache,
|
|
"block_tables": block_tables,
|
|
"slots": slots,
|
|
"input_lengths": input_lengths_tensor,
|
|
"cache_lengths": cache_lengths_tensor,
|
|
"state": state,
|
|
"graph": graph,
|
|
}
|
|
|
|
torch.cuda.synchronize()
|
|
# Run once outside to warmup
|
|
with self._forward_context(
|
|
block_tables=block_tables,
|
|
cu_seqlen_prefill=None,
|
|
input_lengths_tensor=input_lengths_tensor,
|
|
state=state,
|
|
cache_lengths_tensor=cache_lengths_tensor,
|
|
):
|
|
seqlen = Seqlen(
|
|
input_lengths=input_lengths_tensor,
|
|
cache_lengths=cache_lengths_tensor,
|
|
cu_seqlen_q=None,
|
|
max_q=1,
|
|
max_k=max_s,
|
|
)
|
|
self.model.forward(
|
|
inputs_embeds=inputs_embeds,
|
|
position_ids=position_ids,
|
|
cu_seqlen_prefill=None,
|
|
kv_cache=self.kv_cache,
|
|
block_tables=block_tables,
|
|
slots=slots,
|
|
seqlen=seqlen,
|
|
max_s=max_s,
|
|
prefill_cache_indices=None,
|
|
lm_head_indices=None,
|
|
)
|
|
del seqlen
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
with torch.cuda.graph(graph, pool=MEM_POOL):
|
|
seqlen = Seqlen(
|
|
input_lengths=input_lengths_tensor,
|
|
cache_lengths=cache_lengths_tensor,
|
|
cu_seqlen_q=None,
|
|
max_q=1,
|
|
max_k=max_s,
|
|
)
|
|
logits, speculative_logits = self.model.forward(
|
|
inputs_embeds=inputs_embeds,
|
|
position_ids=position_ids,
|
|
cu_seqlen_prefill=None,
|
|
kv_cache=self.kv_cache,
|
|
block_tables=block_tables,
|
|
slots=slots,
|
|
seqlen=seqlen,
|
|
max_s=max_s,
|
|
prefill_cache_indices=None,
|
|
lm_head_indices=None,
|
|
)
|
|
self.cuda_graphs[bs]["logits"] = logits
|
|
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
|
|
torch.cuda.synchronize()
|
|
|
|
def get_vision_embeds(
|
|
self,
|
|
pixel_values: torch.Tensor,
|
|
pixel_attention_mask: torch.Tensor,
|
|
image_sizes: torch.Tensor,
|
|
image_grid_thw: torch.Tensor,
|
|
):
|
|
embeds = self.model.get_vision_embeds(
|
|
pixel_values=pixel_values,
|
|
pixel_attention_mask=pixel_attention_mask,
|
|
image_sizes=image_sizes,
|
|
image_grid_thw=image_grid_thw,
|
|
)
|
|
return embeds
|
|
|
|
def get_inputs_embeds(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
vision_embeds: Optional[torch.Tensor] = None,
|
|
):
|
|
return self.model.get_inputs_embeds(
|
|
input_ids=input_ids,
|
|
vision_embeds=vision_embeds,
|
|
)
|
|
|
|
def encode_images(self, batch):
|
|
if batch.pixel_values is not None:
|
|
device = batch.input_ids.device
|
|
for request_id, image_id, image_input in batch.pixel_values:
|
|
pixel_values = image_input["pixel_values"].to(device)
|
|
|
|
if "pixel_attention_mask" in image_input:
|
|
pixel_attention_mask = image_input["pixel_attention_mask"].to(
|
|
device
|
|
)
|
|
else:
|
|
pixel_attention_mask = None
|
|
|
|
if "image_sizes" in image_input:
|
|
image_sizes = image_input["image_sizes"].to(device)
|
|
else:
|
|
image_sizes = None
|
|
|
|
if "image_grid_thw" in image_input:
|
|
image_grid_thw = image_input["image_grid_thw"].to(device)
|
|
else:
|
|
image_grid_thw = None
|
|
|
|
encoder_outputs = self.get_vision_embeds(
|
|
pixel_values=pixel_values,
|
|
pixel_attention_mask=pixel_attention_mask,
|
|
image_sizes=image_sizes,
|
|
image_grid_thw=image_grid_thw,
|
|
)
|
|
batch.update_encoder_cache(
|
|
encoder_outputs,
|
|
request_id,
|
|
batch.image_positions[request_id][image_id],
|
|
)
|
|
|
|
batch.pixel_values = None
|
|
batch.pixel_attention_mask = None
|
|
batch.image_sizes = None
|
|
|
|
def set_inputs_embeds(self, batch):
|
|
if batch.has_image_inputs:
|
|
self.encode_images(batch)
|
|
vision_embeds = batch.gather_vision_embeds()
|
|
batch.has_image_inputs = False
|
|
else:
|
|
vision_embeds = None
|
|
|
|
inputs_embeds = self.get_inputs_embeds(
|
|
batch.input_ids, vision_embeds=vision_embeds
|
|
)
|
|
|
|
batch.inputs_embeds = inputs_embeds
|
|
|
|
def forward(
|
|
self,
|
|
batch: VlmCausalLMBatch,
|
|
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
# Model Forward
|
|
if batch.speculative_ids is not None:
|
|
input_ids = batch.input_ids
|
|
position_ids = batch.position_ids
|
|
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
|
kv_cache = self.kv_cache
|
|
block_tables = batch.block_tables_tensor
|
|
slots = batch.slots[batch.slot_indices]
|
|
input_lengths = batch.input_lengths_tensor
|
|
max_s = batch.max_current_length
|
|
lm_head_indices = batch.prefill_head_indices
|
|
|
|
speculative_ids = batch.speculative_ids
|
|
|
|
B, speculative_length = speculative_ids.shape
|
|
new_length = speculative_length + 1
|
|
new_input_ids = torch.cat(
|
|
[input_ids.unsqueeze(-1), speculative_ids], dim=1
|
|
).reshape(-1)
|
|
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
|
|
arange_int = arange.to(dtype=torch.int32)
|
|
new_position_ids = (
|
|
position_ids.unsqueeze(-1).expand(B, new_length) + arange
|
|
).view(-1)
|
|
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
|
input_lengths = (
|
|
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
|
).view(-1)
|
|
cache_lengths_tensor = (
|
|
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
|
|
).reshape(-1)
|
|
|
|
# Add Copy the block tables for all members
|
|
block_tables = (
|
|
block_tables.unsqueeze(1)
|
|
.expand(B, new_length, -1)
|
|
.reshape(B * new_length, -1)
|
|
.contiguous()
|
|
)
|
|
max_s = max_s + speculative_length
|
|
|
|
input_ids = new_input_ids
|
|
position_ids = new_position_ids
|
|
else:
|
|
input_ids = batch.input_ids
|
|
inputs_embeds = batch.inputs_embeds
|
|
position_ids = batch.position_ids
|
|
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
|
kv_cache = self.kv_cache
|
|
block_tables = batch.block_tables_tensor
|
|
slots = batch.slots[batch.slot_indices]
|
|
input_lengths = batch.input_lengths_tensor
|
|
cache_lengths_tensor = batch.cache_lengths_tensor
|
|
max_s = batch.max_current_length
|
|
lm_head_indices = batch.prefill_head_indices
|
|
|
|
if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
|
|
if position_ids.dim() == 1 and batch.prefilling:
|
|
position_ids = self.model.get_position_ids(
|
|
input_ids, batch.image_grid_thw
|
|
)
|
|
batch.position_ids = position_ids
|
|
|
|
attention_mask = None
|
|
attention_mask_forward = None
|
|
if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
|
|
attention_mask = self.model.get_attention_mask(
|
|
input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
|
|
)
|
|
min_dtype = torch.finfo(self.dtype).min
|
|
attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
|
|
input_ids.device
|
|
)
|
|
attention_mask = attention_mask.reshape(-1)
|
|
|
|
# Try to find an associated cuda graph
|
|
bs = input_ids.shape[0]
|
|
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
|
if sorted_padded_bs:
|
|
# Get associated cuda graph
|
|
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
|
|
else:
|
|
cuda_graph = None
|
|
if cu_seqlen_prefill is not None or cuda_graph is None:
|
|
if ATTENTION == "flashinfer":
|
|
block_tables = block_tables_to_ragged(
|
|
block_tables=block_tables,
|
|
input_lengths=batch.input_lengths,
|
|
cache_lengths=batch.cache_lengths,
|
|
input_lengths_tensor=batch.input_lengths_tensor,
|
|
cache_lengths_tensor=batch.cache_lengths_tensor,
|
|
max_current_length=batch.max_current_length,
|
|
)
|
|
with self._forward_context(
|
|
block_tables=block_tables,
|
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
|
input_lengths_tensor=input_lengths,
|
|
cache_lengths_tensor=cache_lengths_tensor,
|
|
attention_mask=attention_mask,
|
|
):
|
|
seqlen = Seqlen(
|
|
input_lengths=input_lengths,
|
|
cache_lengths=cache_lengths_tensor,
|
|
cu_seqlen_q=cu_seqlen_prefill,
|
|
max_q=batch.max_input_length,
|
|
max_k=batch.max_current_length,
|
|
)
|
|
logits, speculative_logits = self.model.forward(
|
|
inputs_embeds=inputs_embeds,
|
|
position_ids=position_ids,
|
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
|
kv_cache=kv_cache,
|
|
block_tables=block_tables,
|
|
slots=slots,
|
|
seqlen=seqlen,
|
|
max_s=max_s,
|
|
prefill_cache_indices=batch.prefill_cache_indices,
|
|
lm_head_indices=lm_head_indices,
|
|
attention_mask=attention_mask_forward,
|
|
)
|
|
if batch.prefill_cache_indices is not None:
|
|
batch.prefill_cache_indices = None
|
|
batch.image_grid_thw = None
|
|
batch.free_encoder_cache()
|
|
return logits, speculative_logits
|
|
|
|
# Copy inputs to the static inputs of the cuda graph
|
|
# Static inputs are potentially padded
|
|
cuda_graph["inputs_embeds"][: inputs_embeds.shape[0]] = inputs_embeds
|
|
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
|
if ATTENTION == "flashinfer":
|
|
block_tables = block_tables_to_ragged(
|
|
block_tables=block_tables,
|
|
input_lengths=batch.input_lengths,
|
|
cache_lengths=batch.cache_lengths,
|
|
input_lengths_tensor=batch.input_lengths_tensor,
|
|
cache_lengths_tensor=batch.cache_lengths_tensor,
|
|
max_current_length=batch.max_current_length,
|
|
)
|
|
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
|
else:
|
|
cuda_graph["block_tables"][
|
|
: block_tables.shape[0], : block_tables.shape[1]
|
|
] = block_tables
|
|
|
|
# XXX: This is working only because block 0 is reserved for the healthcheck
|
|
# so it doesn't matter if we override it with bogus values.
|
|
cuda_graph["slots"].fill_(0)
|
|
cuda_graph["slots"][: slots.shape[0]] = slots
|
|
cuda_graph["input_lengths"].zero_()
|
|
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
|
cuda_graph["cache_lengths"].zero_()
|
|
cuda_graph["cache_lengths"][
|
|
: cache_lengths_tensor.shape[0]
|
|
] = cache_lengths_tensor
|
|
|
|
with self._forward_context(
|
|
block_tables=cuda_graph["block_tables"],
|
|
cu_seqlen_prefill=None,
|
|
input_lengths_tensor=cuda_graph["input_lengths"],
|
|
cache_lengths_tensor=cuda_graph["cache_lengths"],
|
|
state=cuda_graph["state"],
|
|
):
|
|
# Replay the graph
|
|
cuda_graph["graph"].replay()
|
|
|
|
# Slice output to the correct shape
|
|
speculative_logits = (
|
|
cuda_graph["speculative_logits"][:bs]
|
|
if cuda_graph["speculative_logits"] is not None
|
|
else None
|
|
)
|
|
logits = cuda_graph["logits"][:bs]
|
|
|
|
batch.free_encoder_cache()
|
|
return logits, speculative_logits
|