add improvements

This commit is contained in:
Mohit Sharma 2025-04-21 15:25:03 +00:00
parent 7237e8e6bf
commit be8e60a918
12 changed files with 654 additions and 249 deletions

View File

@ -39,7 +39,7 @@ httpcore==1.0.7
# via httpx
httpx==0.28.1
# via openai
huggingface-hub==0.29.3
huggingface-hub==0.30.1
# via
# text-generation-integration-tests (pyproject.toml)
# text-generation

View File

@ -128,9 +128,6 @@ try:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.models.pali_gemma import (
PaliGemmaBatch,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
@ -1196,6 +1193,7 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif FLASH_TRANSFORMERS_BACKEND:
from transformers import Gemma3ForConditionalGeneration as Gemma3Model
@ -1208,6 +1206,7 @@ def get_model(
speculator=speculator,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
support_chunking=False,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
@ -1583,6 +1582,7 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
# TODO: Uncomment when transformers is refactored and cross attn is added
# elif FLASH_TRANSFORMERS_BACKEND:
@ -1676,7 +1676,6 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
)
elif FLASH_TRANSFORMERS_BACKEND:
from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel
@ -1689,7 +1688,6 @@ def get_model(
speculator=speculator,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
batch_class=PaliGemmaBatch,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))

View File

@ -700,6 +700,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
self.dtype = weights.dtype
def get_attention_mask(
self,
@ -762,6 +763,38 @@ class Gemma3ForConditionalGeneration(nn.Module):
else:
return torch.where(full_attention_mask, 0, min_dtype).to(device)
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
**kwargs,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_model(pixel_values)
vision_outputs = self.post_vision_model_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multimodal_projector(vision_outputs)
image_features = image_features.view(-1, image_features.shape[-1])
return image_features
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# Replace the image token embeddings with the vision features
image_token_mask = (input_ids == self.config.image_token_index).to(
input_ids.device
)
inputs_embeds[image_token_mask] = vision_embeds.view(
-1, vision_embeds.shape[-1]
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -781,26 +814,17 @@ class Gemma3ForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_model(pixel_values)
vision_outputs = self.post_vision_model_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multimodal_projector(vision_outputs)
image_token_mask = (input_ids == self.config.image_token_index).to(
input_ids.device
)
inputs_embeds[image_token_mask] = image_features.view(
-1, image_features.shape[-1]
)
if torch.any(image_token_mask):
attention_mask = self.get_attention_mask(
input_ids,
cu_seqlen_prefill,

View File

@ -62,6 +62,37 @@ class PaliGemmaForConditionalGeneration(nn.Module):
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
self.dtype = weights.dtype
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
**kwargs,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
image_features = image_features.view(
image_features.shape[0], image_features.shape[1], -1
)
return image_features
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
mask = input_ids == self.config.image_token_index
inputs_embeds[mask] = vision_embeds.view(-1, vision_embeds.shape[-1])
return inputs_embeds
def forward(
self,
@ -81,27 +112,13 @@ class PaliGemmaForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
# mask where image or padding tokens
mask = input_ids == self.config.image_token_index
# insert image features into input embeddings
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -476,38 +476,18 @@ class Idefics3ForConditionalGeneration(nn.Module):
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
def get_vision_embeds(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
pixel_values: torch.FloatTensor,
pixel_attention_mask: torch.BoolTensor,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
@ -561,10 +541,54 @@ class Idefics3ForConditionalGeneration(nn.Module):
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: torch.BoolTensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is None and pixel_values is not None:
vision_embeds = self.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
inputs_embeds: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -163,27 +163,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
)
return inputs_embeds
def forward(
def get_vision_embeds(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
pixel_values: torch.FloatTensor,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
@ -218,8 +203,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
self.config.vision_config.image_size // self.config.vision_config.patch_size
)
new_image_features = []
@ -256,9 +240,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
@ -266,11 +248,51 @@ class LlavaNextForConditionalGeneration(nn.Module):
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
return image_features.view(-1, image_features.shape[-1])
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is None and pixel_values is not None:
vision_embeds = self.get_vision_embeds(
pixel_values=pixel_values,
image_sizes=image_sizes,
)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -959,6 +959,7 @@ class MllamaForConditionalGeneration(nn.Module):
# XXX: Putting these as optional so that the cuda warmup calls can go through.
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
inputs_embeds=None,
):
if cross_attention_states is not None:
seqlen_q = len(image_indices)

View File

@ -922,6 +922,29 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
)
return position_ids
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if vision_embeds is not None and len(vision_embeds) > 0:
inputs_embeds[input_ids == self.image_token_id] = vision_embeds
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -943,17 +966,8 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
inputs_embeds: Optional[torch.Tensor] = None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -500,6 +500,29 @@ class Qwen2VLForConditionalGeneration(nn.Module):
)
return position_ids
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if vision_embeds is not None and len(vision_embeds) > 0:
inputs_embeds[input_ids == self.image_token_id] = vision_embeds
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -520,17 +543,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
inputs_embeds: Optional[torch.Tensor] = None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -29,6 +29,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
aspect_ratio_mask: Optional[torch.Tensor] = None
cross_attention_states: Optional[torch.Tensor] = None
def prepare_for_prefill(self):
super(VlmCausalLMBatch, self).prepare_for_prefill()
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
@ -196,6 +199,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
class MllamaCausalLM(VlmCausalLM):
def get_input_embeddings(self, batch):
batch.inputs_embeds = None
def forward(
self,
batch: MllamaCausalLMBatch,

View File

@ -163,6 +163,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
processor_kwargs=None,
kv_cache_dtype: Optional[torch.dtype] = None,
batch_class=VlmCausalLMBatch,
support_chunking: bool = True,
):
self.batch_class = batch_class
self.quantize = quantize
@ -304,7 +305,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
device=device,
rank=rank,
world_size=world_size,
support_chunking=True,
support_chunking=support_chunking,
)
# Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
@ -339,6 +340,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
trust_remote_code: bool = False,
batch_class: Optional[type] = VlmCausalLMBatch,
processor_kwargs: Optional[dict] = None,
support_chunking: bool = True,
):
return cls(
model_id=model_id,
@ -350,6 +352,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
trust_remote_code=trust_remote_code,
batch_class=batch_class,
processor_kwargs=processor_kwargs,
support_chunking=support_chunking,
)
def _model_forward(

View File

@ -13,7 +13,7 @@ from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
)
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION, MEM_POOL
from loguru import logger
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
@ -119,8 +119,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
return image_str, IDEFICS2_FAKE_TOKEN
if config.model_type == "idefics3":
# TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id]
n_cols = image_input["cols"][0][image_id]
n_rows = image_input[image_id]["rows"][0][0]
n_cols = image_input[image_id]["cols"][0][0]
image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
/ (config.scale_factor**2)
@ -135,7 +135,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
)
return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
height, width = image_input[image_id]["image_sizes"][0]
num_features = get_number_of_features(height, width, config)
log_master(
@ -147,12 +147,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens, "<image>"
elif config.model_type == "qwen2_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
elif config.model_type == "qwen2_5_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
@ -344,8 +344,155 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
# Process images first. We need all of them so that the processor
# can make the image splits the same size. And we need the final
kwargs = {}
if (
hasattr(processor, "image_processor_class")
and processor.image_processor_class == "Idefics3ImageProcessor"
):
kwargs["return_row_col_info"] = True
max_length = 0
vocab = tokenizer.get_vocab()
config.image_token_index = (
config.image_token_index
if hasattr(config, "image_token_index")
else config.image_token_id
)
batch_tokenized_inputs: List[List[int]] = []
batch_image_inputs: List[Optional[List[dict]]] = []
batch_image_positions: List[Optional[List[ImagePositions]]] = []
for i, r in enumerate(requests):
text_parts = []
image_inputs = []
image_texts = []
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
text_parts.append(chunk.text)
continue
if chunk_type != "image":
raise RuntimeError(f"Invalid chunk type {chunk_type}")
img = Image.open(BytesIO(chunk.image.data))
if config.model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
img = img.resize((img.width * 2, img.height * 2))
if config.model_type in {"paligemma"}:
img = img.convert("RGB")
if config.model_type not in {"llava_next", "gemma3", "llama4"}:
img = [img]
image_input = processor.image_processor(
[img], return_tensors="pt", **kwargs
)
image_inputs.append(image_input)
img_text, id_token_str = image_text_replacement(
processor, image_input, config, 0
)
text_parts.append(img_text)
image_texts.append([image_id, id_token_str, img_text])
image_id += 1
full_text = image_text_replacement_fixup(config, "".join(text_parts))
input_ids = tokenizer(
full_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"]
max_length = max(max_length, len(input_ids))
if len(image_inputs) > 0:
img_start_token = vocab[image_texts[0][1]]
image_positions = cls.get_image_positions(
input_ids, image_texts, img_start_token, config, tokenizer
)
else:
image_inputs = None
image_positions = None
batch_tokenized_inputs.append(input_ids)
batch_image_inputs.append(image_inputs)
batch_image_positions.append(image_positions)
return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
@classmethod
def get_image_positions(
cls,
input_ids: List[int],
image_texts: List[Tuple[int, str, str]],
img_start_token: int,
config,
tokenizer: PreTrainedTokenizerBase,
) -> List[ImagePositions]:
image_positions = []
num_images = len(image_texts)
input_ids_t = torch.as_tensor(input_ids, dtype=torch.int32)
img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
last_pos = 0
for i in range(num_images):
image_id, img_start_token_str, img_text = image_texts[i]
img_text = image_text_replacement_fixup(config, img_text)
if config.model_type == "gemma3":
img_text = img_text.replace("\n\n", "")
tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"]
pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
index = img_start_token_pos[pos]
is_embed = torch.tensor(tokens) == config.image_token_index
num_placeholder_tokens = is_embed.sum().item()
length = len(tokens)
if num_placeholder_tokens == length:
is_embed = None
pos = ImagePositions(
offset=index,
length=length,
id=image_id,
num_placeholder_tokens=num_placeholder_tokens,
is_embed=is_embed,
)
image_positions.append(pos)
last_pos = index + length
if (
config.model_type == "idefics2"
and i + 1 != num_images
and input_ids[last_pos] == config.image_token_index
):
fake_token = last_pos - 1
fake_token_index = torch.searchsorted(
img_start_token_pos, fake_token, right=False
)
img_start_token_pos[fake_token_index] = last_pos
image_texts[i + 1][2] = image_texts[i + 1][2][
len(img_start_token_str) :
]
return image_positions
@classmethod
def batch_tokenized_inputs2(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
# sizes to insert correct number of image tokens.
kwargs = {}
if (
@ -374,21 +521,20 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if config.model_type in {"llava_next", "gemma3", "llama4"}:
image = image
elif config.model_type in {"paligemma"}:
image = image.convert("RGB")
else:
image = [image]
pixel_values = processor.image_processor(
image_input = processor.image_processor(
[image], return_tensors="pt", **kwargs
)
image_inputs.append(pixel_values)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
if len(image_inputs) > 0:
batch_image_inputs[i] = image_inputs
# pixel_values = processor.image_processor(
# all_images, return_tensors="pt", **kwargs
# )
batch_image_positions = []
batch_tokenized_inputs = []
@ -554,29 +700,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if image_id not in self.encoder_cache[i]:
self.pixel_values.append((i, image_position, image_inputs))
# scheduled_image_pixel_values.append(image_inputs)
self.image_inputs[i][j] = None
# if self.has_image and len(scheduled_image_pixel_values):
# self.pixel_values = [
# d["pixel_values"].to(device) for d in scheduled_image_pixel_values
# ]
# if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
# self.pixel_attention_mask = [
# d["pixel_attention_mask"].to(device)
# for d in scheduled_image_pixel_values
# ]
# if "image_sizes" in scheduled_image_pixel_values[0]:
# self.image_sizes = [
# d["image_sizes"].to(device) for d in scheduled_image_pixel_values
# ]
# if "image_grid_thw" in scheduled_image_pixel_values[0]:
# self.image_grid_thw = [
# d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
# ]
if not self.has_image:
self.pixel_values = None
self.pixel_attention_mask = None
@ -637,12 +762,21 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if is_embed is not None:
is_embed = is_embed[start_idx:end_idx]
from loguru import logger
logger.info(
f"image_id {image_id} start_idx {start_idx} end_idx {end_idx}, length {length}"
)
mm_embeds_item = gather_image_embeds(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
if mm_embeds_item is not None:
mm_embeds.append(mm_embeds_item)
if len(mm_embeds) == 0:
return None
return torch.cat(mm_embeds, dim=0).to(device)
def free_encoder_cache(self):
@ -662,6 +796,7 @@ class VlmCausalLM(FlashCausalLM):
batch_class=VlmCausalLMBatch,
revision,
trust_remote_code: bool,
support_chunking: bool = True,
**kwargs,
):
if PREFIX_CACHING:
@ -679,8 +814,7 @@ class VlmCausalLM(FlashCausalLM):
model_id=model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# FIXME: VLM do not work with context chunking yet
support_chunking=False,
support_chunking=support_chunking,
**kwargs,
)
@ -688,6 +822,153 @@ class VlmCausalLM(FlashCausalLM):
def batch_type(self) -> Type[VlmCausalLMBatch]:
return self.batch_class
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
input_lengths = [max_s] * bs
cache_lengths = [0] * bs
if max_bs is None:
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
input_embeds = torch.zeros(
(bs, self.model.config.text_config.hidden_size),
device=self.device,
dtype=self.dtype,
)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
config = getattr(self.model, "config", None)
rope_scaling = getattr(config, "rope_scaling", None) if config else None
if ( # mrope have position_ids per section, if so repeat n times
isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope"
):
n_sections = len(self.model.config.rope_scaling["mrope_section"])
position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
)
cache_lengths_tensor = torch.zeros(
bs, dtype=torch.int32, device=self.device
)
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).repeat(bs)
block_tables = block_tables.reshape((bs, max_bt))
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=input_lengths,
cache_lengths=cache_lengths,
input_lengths_tensor=input_lengths_tensor,
cache_lengths_tensor=cache_lengths_tensor,
max_current_length=max_s,
)
else:
if bs > max_bs:
raise RuntimeError(
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
)
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
input_embeds = self.cuda_graphs[max_bs]["input_embeds"][:bs]
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
if ATTENTION == "flashinfer":
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
else:
block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
slots = self.cuda_graphs[max_bs]["slots"][:bs]
input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
create_decode_state_cuda_graphs,
)
block_tables_ptr = torch.zeros(
bs + 1, dtype=torch.int32, device=self.device
)
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
state = create_decode_state_cuda_graphs(
device=input_ids.device,
block_tables=block_tables,
block_tables_ptr=block_tables_ptr,
last_page_len=last_page_len,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
else:
state = None
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"input_embeds": input_embeds,
"position_ids": position_ids,
"kv_cache": self.kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"cache_lengths": cache_lengths_tensor,
"state": state,
"graph": graph,
}
torch.cuda.synchronize()
# Run once outside to warmup
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=None,
input_lengths_tensor=input_lengths_tensor,
state=state,
cache_lengths_tensor=cache_lengths_tensor,
):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
self.model.forward(
input_ids=input_ids,
inputs_embeds=input_embeds,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
del seqlen
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
inputs_embeds=input_embeds,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize()
def get_vision_embeds(
self,
pixel_values: torch.Tensor,
@ -901,6 +1182,7 @@ class VlmCausalLM(FlashCausalLM):
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["input_embeds"][: inputs_embeds.shape[0]] = inputs_embeds
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(