add improvements

This commit is contained in:
Mohit Sharma 2025-04-21 15:25:03 +00:00
parent 7237e8e6bf
commit be8e60a918
12 changed files with 654 additions and 249 deletions

View File

@ -39,7 +39,7 @@ httpcore==1.0.7
# via httpx # via httpx
httpx==0.28.1 httpx==0.28.1
# via openai # via openai
huggingface-hub==0.29.3 huggingface-hub==0.30.1
# via # via
# text-generation-integration-tests (pyproject.toml) # text-generation-integration-tests (pyproject.toml)
# text-generation # text-generation

View File

@ -128,9 +128,6 @@ try:
from text_generation_server.models.custom_modeling.flash_neox_modeling import ( from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM, FlashGPTNeoXForCausalLM,
) )
from text_generation_server.models.pali_gemma import (
PaliGemmaBatch,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration, PaliGemmaForConditionalGeneration,
) )
@ -1196,6 +1193,7 @@ def get_model(
default_dtype=torch.bfloat16, default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
) )
elif FLASH_TRANSFORMERS_BACKEND: elif FLASH_TRANSFORMERS_BACKEND:
from transformers import Gemma3ForConditionalGeneration as Gemma3Model from transformers import Gemma3ForConditionalGeneration as Gemma3Model
@ -1208,6 +1206,7 @@ def get_model(
speculator=speculator, speculator=speculator,
dtype=torch.bfloat16, dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
support_chunking=False,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
@ -1583,6 +1582,7 @@ def get_model(
default_dtype=torch.bfloat16, default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
) )
# TODO: Uncomment when transformers is refactored and cross attn is added # TODO: Uncomment when transformers is refactored and cross attn is added
# elif FLASH_TRANSFORMERS_BACKEND: # elif FLASH_TRANSFORMERS_BACKEND:
@ -1676,7 +1676,6 @@ def get_model(
default_dtype=torch.bfloat16, default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
) )
elif FLASH_TRANSFORMERS_BACKEND: elif FLASH_TRANSFORMERS_BACKEND:
from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel
@ -1689,7 +1688,6 @@ def get_model(
speculator=speculator, speculator=speculator,
dtype=torch.bfloat16, dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
batch_class=PaliGemmaBatch,
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))

View File

@ -700,6 +700,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
self.pad_token_id = ( self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1 config.pad_token_id if config.pad_token_id is not None else -1
) )
self.dtype = weights.dtype
def get_attention_mask( def get_attention_mask(
self, self,
@ -762,6 +763,38 @@ class Gemma3ForConditionalGeneration(nn.Module):
else: else:
return torch.where(full_attention_mask, 0, min_dtype).to(device) return torch.where(full_attention_mask, 0, min_dtype).to(device)
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
**kwargs,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_model(pixel_values)
vision_outputs = self.post_vision_model_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multimodal_projector(vision_outputs)
image_features = image_features.view(-1, image_features.shape[-1])
return image_features
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# Replace the image token embeddings with the vision features
image_token_mask = (input_ids == self.config.image_token_index).to(
input_ids.device
)
inputs_embeds[image_token_mask] = vision_embeds.view(
-1, vision_embeds.shape[-1]
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -781,26 +814,17 @@ class Gemma3ForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.Tensor] = None, image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
max_s += 1 max_s += 1
position_ids += 1 position_ids += 1
if pixel_values is not None: image_token_mask = (input_ids == self.config.image_token_index).to(
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) input_ids.device
image_outputs = self.vision_model(pixel_values) )
vision_outputs = self.post_vision_model_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multimodal_projector(vision_outputs)
image_token_mask = (input_ids == self.config.image_token_index).to( if torch.any(image_token_mask):
input_ids.device
)
inputs_embeds[image_token_mask] = image_features.view(
-1, image_features.shape[-1]
)
attention_mask = self.get_attention_mask( attention_mask = self.get_attention_mask(
input_ids, input_ids,
cu_seqlen_prefill, cu_seqlen_prefill,

View File

@ -62,6 +62,37 @@ class PaliGemmaForConditionalGeneration(nn.Module):
self.pad_token_id = ( self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1 config.pad_token_id if config.pad_token_id is not None else -1
) )
self.dtype = weights.dtype
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
**kwargs,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
image_features = image_features.view(
image_features.shape[0], image_features.shape[1], -1
)
return image_features
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
mask = input_ids == self.config.image_token_index
inputs_embeds[mask] = vision_embeds.view(-1, vision_embeds.shape[-1])
return inputs_embeds
def forward( def forward(
self, self,
@ -81,27 +112,13 @@ class PaliGemmaForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.Tensor] = None, image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1. # TODO This is odd but apparently pali gemma position ids start at 1.
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
max_s += 1 max_s += 1
position_ids += 1 position_ids += 1
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
# mask where image or padding tokens
mask = input_ids == self.config.image_token_index
# insert image features into input embeddings
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
hidden_states = self.text_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,

View File

@ -476,6 +476,96 @@ class Idefics3ForConditionalGeneration(nn.Module):
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds return inputs_embeds
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: torch.BoolTensor,
**kwargs,
):
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask/pP p
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: torch.BoolTensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is None and pixel_values is not None:
vision_embeds = self.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -497,74 +587,8 @@ class Idefics3ForConditionalGeneration(nn.Module):
video_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None, image_indices=None,
inputs_embeds: Optional[torch.Tensor] = None,
): ):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask/pP p
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
hidden_states = self.text_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,

View File

@ -163,6 +163,116 @@ class LlavaNextForConditionalGeneration(nn.Module):
) )
return inputs_embeds return inputs_embeds
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
image_sizes: Optional[torch.LongTensor] = None,
**kwargs,
):
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
)
image_features = self.vision_tower(pixel_values)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size // self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
return image_features.view(-1, image_features.shape[-1])
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is None and pixel_values is not None:
vision_embeds = self.get_vision_embeds(
pixel_values=pixel_values,
image_sizes=image_sizes,
)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -181,96 +291,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.LongTensor] = None, image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
): ):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
)
image_features = self.vision_tower(pixel_values)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
)
hidden_states = self.text_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,

View File

@ -959,6 +959,7 @@ class MllamaForConditionalGeneration(nn.Module):
# XXX: Putting these as optional so that the cuda warmup calls can go through. # XXX: Putting these as optional so that the cuda warmup calls can go through.
cross_attention_states: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None, image_indices=None,
inputs_embeds=None,
): ):
if cross_attention_states is not None: if cross_attention_states is not None:
seqlen_q = len(image_indices) seqlen_q = len(image_indices)

View File

@ -922,6 +922,29 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
) )
return position_ids return position_ids
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if vision_embeds is not None and len(vision_embeds) > 0:
inputs_embeds[input_ids == self.image_token_id] = vision_embeds
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -943,17 +966,8 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None, image_indices=None,
inputs_embeds: Optional[torch.Tensor] = None,
): ):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds
hidden_states = self.text_model( hidden_states = self.text_model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,

View File

@ -500,6 +500,29 @@ class Qwen2VLForConditionalGeneration(nn.Module):
) )
return position_ids return position_ids
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if vision_embeds is not None and len(vision_embeds) > 0:
inputs_embeds[input_ids == self.image_token_id] = vision_embeds
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -520,17 +543,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None, image_indices=None,
inputs_embeds: Optional[torch.Tensor] = None,
): ):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds
hidden_states = self.text_model( hidden_states = self.text_model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,

View File

@ -29,6 +29,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
aspect_ratio_mask: Optional[torch.Tensor] = None aspect_ratio_mask: Optional[torch.Tensor] = None
cross_attention_states: Optional[torch.Tensor] = None cross_attention_states: Optional[torch.Tensor] = None
def prepare_for_prefill(self):
super(VlmCausalLMBatch, self).prepare_for_prefill()
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches): def concatenate(cls, batches):
@ -196,6 +199,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
class MllamaCausalLM(VlmCausalLM): class MllamaCausalLM(VlmCausalLM):
def get_input_embeddings(self, batch):
batch.inputs_embeds = None
def forward( def forward(
self, self,
batch: MllamaCausalLMBatch, batch: MllamaCausalLMBatch,

View File

@ -163,6 +163,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
processor_kwargs=None, processor_kwargs=None,
kv_cache_dtype: Optional[torch.dtype] = None, kv_cache_dtype: Optional[torch.dtype] = None,
batch_class=VlmCausalLMBatch, batch_class=VlmCausalLMBatch,
support_chunking: bool = True,
): ):
self.batch_class = batch_class self.batch_class = batch_class
self.quantize = quantize self.quantize = quantize
@ -304,7 +305,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
device=device, device=device,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
support_chunking=True, support_chunking=support_chunking,
) )
# Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
@ -339,6 +340,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
trust_remote_code: bool = False, trust_remote_code: bool = False,
batch_class: Optional[type] = VlmCausalLMBatch, batch_class: Optional[type] = VlmCausalLMBatch,
processor_kwargs: Optional[dict] = None, processor_kwargs: Optional[dict] = None,
support_chunking: bool = True,
): ):
return cls( return cls(
model_id=model_id, model_id=model_id,
@ -350,6 +352,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
batch_class=batch_class, batch_class=batch_class,
processor_kwargs=processor_kwargs, processor_kwargs=processor_kwargs,
support_chunking=support_chunking,
) )
def _model_forward( def _model_forward(

View File

@ -13,7 +13,7 @@ from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch, FlashCausalLMBatch,
FlashCausalLM, FlashCausalLM,
) )
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION, MEM_POOL
from loguru import logger from loguru import logger
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
from transformers import AutoProcessor from transformers import AutoProcessor
@ -119,8 +119,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
return image_str, IDEFICS2_FAKE_TOKEN return image_str, IDEFICS2_FAKE_TOKEN
if config.model_type == "idefics3": if config.model_type == "idefics3":
# TODO: implement this in a more general way # TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id] n_rows = image_input[image_id]["rows"][0][0]
n_cols = image_input["cols"][0][image_id] n_cols = image_input[image_id]["cols"][0][0]
image_seq_len = int( image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2) ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
/ (config.scale_factor**2) / (config.scale_factor**2)
@ -135,7 +135,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
) )
return image_str, IDEFICS3_FAKE_IMAGE_TOKEN return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
elif config.model_type == "llava_next": elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id] height, width = image_input[image_id]["image_sizes"][0]
num_features = get_number_of_features(height, width, config) num_features = get_number_of_features(height, width, config)
log_master( log_master(
@ -147,12 +147,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
elif config.model_type == "paligemma": elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens, "<image>" return "<image>" * config.text_config.num_image_tokens, "<image>"
elif config.model_type == "qwen2_vl": elif config.model_type == "qwen2_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4 num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>" return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
elif config.model_type == "qwen2_5_vl": elif config.model_type == "qwen2_5_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4 num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>" return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
@ -344,8 +344,155 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
def batch_tokenized_inputs( def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
): ):
# Process images first. We need all of them so that the processor kwargs = {}
# can make the image splits the same size. And we need the final if (
hasattr(processor, "image_processor_class")
and processor.image_processor_class == "Idefics3ImageProcessor"
):
kwargs["return_row_col_info"] = True
max_length = 0
vocab = tokenizer.get_vocab()
config.image_token_index = (
config.image_token_index
if hasattr(config, "image_token_index")
else config.image_token_id
)
batch_tokenized_inputs: List[List[int]] = []
batch_image_inputs: List[Optional[List[dict]]] = []
batch_image_positions: List[Optional[List[ImagePositions]]] = []
for i, r in enumerate(requests):
text_parts = []
image_inputs = []
image_texts = []
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
text_parts.append(chunk.text)
continue
if chunk_type != "image":
raise RuntimeError(f"Invalid chunk type {chunk_type}")
img = Image.open(BytesIO(chunk.image.data))
if config.model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
img = img.resize((img.width * 2, img.height * 2))
if config.model_type in {"paligemma"}:
img = img.convert("RGB")
if config.model_type not in {"llava_next", "gemma3", "llama4"}:
img = [img]
image_input = processor.image_processor(
[img], return_tensors="pt", **kwargs
)
image_inputs.append(image_input)
img_text, id_token_str = image_text_replacement(
processor, image_input, config, 0
)
text_parts.append(img_text)
image_texts.append([image_id, id_token_str, img_text])
image_id += 1
full_text = image_text_replacement_fixup(config, "".join(text_parts))
input_ids = tokenizer(
full_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"]
max_length = max(max_length, len(input_ids))
if len(image_inputs) > 0:
img_start_token = vocab[image_texts[0][1]]
image_positions = cls.get_image_positions(
input_ids, image_texts, img_start_token, config, tokenizer
)
else:
image_inputs = None
image_positions = None
batch_tokenized_inputs.append(input_ids)
batch_image_inputs.append(image_inputs)
batch_image_positions.append(image_positions)
return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
@classmethod
def get_image_positions(
cls,
input_ids: List[int],
image_texts: List[Tuple[int, str, str]],
img_start_token: int,
config,
tokenizer: PreTrainedTokenizerBase,
) -> List[ImagePositions]:
image_positions = []
num_images = len(image_texts)
input_ids_t = torch.as_tensor(input_ids, dtype=torch.int32)
img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
last_pos = 0
for i in range(num_images):
image_id, img_start_token_str, img_text = image_texts[i]
img_text = image_text_replacement_fixup(config, img_text)
if config.model_type == "gemma3":
img_text = img_text.replace("\n\n", "")
tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"]
pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
index = img_start_token_pos[pos]
is_embed = torch.tensor(tokens) == config.image_token_index
num_placeholder_tokens = is_embed.sum().item()
length = len(tokens)
if num_placeholder_tokens == length:
is_embed = None
pos = ImagePositions(
offset=index,
length=length,
id=image_id,
num_placeholder_tokens=num_placeholder_tokens,
is_embed=is_embed,
)
image_positions.append(pos)
last_pos = index + length
if (
config.model_type == "idefics2"
and i + 1 != num_images
and input_ids[last_pos] == config.image_token_index
):
fake_token = last_pos - 1
fake_token_index = torch.searchsorted(
img_start_token_pos, fake_token, right=False
)
img_start_token_pos[fake_token_index] = last_pos
image_texts[i + 1][2] = image_texts[i + 1][2][
len(img_start_token_str) :
]
return image_positions
@classmethod
def batch_tokenized_inputs2(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
# sizes to insert correct number of image tokens. # sizes to insert correct number of image tokens.
kwargs = {} kwargs = {}
if ( if (
@ -374,21 +521,20 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if config.model_type in {"llava_next", "gemma3", "llama4"}: if config.model_type in {"llava_next", "gemma3", "llama4"}:
image = image image = image
elif config.model_type in {"paligemma"}:
image = image.convert("RGB")
else: else:
image = [image] image = [image]
pixel_values = processor.image_processor( image_input = processor.image_processor(
[image], return_tensors="pt", **kwargs [image], return_tensors="pt", **kwargs
) )
image_inputs.append(pixel_values) image_inputs.append(image_input)
else: else:
raise RuntimeError(f"Invalid chunk type {chunk_type}") raise RuntimeError(f"Invalid chunk type {chunk_type}")
if len(image_inputs) > 0: if len(image_inputs) > 0:
batch_image_inputs[i] = image_inputs batch_image_inputs[i] = image_inputs
# pixel_values = processor.image_processor(
# all_images, return_tensors="pt", **kwargs
# )
batch_image_positions = [] batch_image_positions = []
batch_tokenized_inputs = [] batch_tokenized_inputs = []
@ -554,29 +700,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if image_id not in self.encoder_cache[i]: if image_id not in self.encoder_cache[i]:
self.pixel_values.append((i, image_position, image_inputs)) self.pixel_values.append((i, image_position, image_inputs))
# scheduled_image_pixel_values.append(image_inputs)
self.image_inputs[i][j] = None self.image_inputs[i][j] = None
# if self.has_image and len(scheduled_image_pixel_values):
# self.pixel_values = [
# d["pixel_values"].to(device) for d in scheduled_image_pixel_values
# ]
# if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
# self.pixel_attention_mask = [
# d["pixel_attention_mask"].to(device)
# for d in scheduled_image_pixel_values
# ]
# if "image_sizes" in scheduled_image_pixel_values[0]:
# self.image_sizes = [
# d["image_sizes"].to(device) for d in scheduled_image_pixel_values
# ]
# if "image_grid_thw" in scheduled_image_pixel_values[0]:
# self.image_grid_thw = [
# d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
# ]
if not self.has_image: if not self.has_image:
self.pixel_values = None self.pixel_values = None
self.pixel_attention_mask = None self.pixel_attention_mask = None
@ -637,12 +762,21 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if is_embed is not None: if is_embed is not None:
is_embed = is_embed[start_idx:end_idx] is_embed = is_embed[start_idx:end_idx]
from loguru import logger
logger.info(
f"image_id {image_id} start_idx {start_idx} end_idx {end_idx}, length {length}"
)
mm_embeds_item = gather_image_embeds( mm_embeds_item = gather_image_embeds(
encoder_output[start_idx:end_idx], encoder_output[start_idx:end_idx],
is_embed=is_embed, is_embed=is_embed,
) )
mm_embeds.append(mm_embeds_item) if mm_embeds_item is not None:
mm_embeds.append(mm_embeds_item)
if len(mm_embeds) == 0:
return None
return torch.cat(mm_embeds, dim=0).to(device) return torch.cat(mm_embeds, dim=0).to(device)
def free_encoder_cache(self): def free_encoder_cache(self):
@ -662,6 +796,7 @@ class VlmCausalLM(FlashCausalLM):
batch_class=VlmCausalLMBatch, batch_class=VlmCausalLMBatch,
revision, revision,
trust_remote_code: bool, trust_remote_code: bool,
support_chunking: bool = True,
**kwargs, **kwargs,
): ):
if PREFIX_CACHING: if PREFIX_CACHING:
@ -679,8 +814,7 @@ class VlmCausalLM(FlashCausalLM):
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
# FIXME: VLM do not work with context chunking yet support_chunking=support_chunking,
support_chunking=False,
**kwargs, **kwargs,
) )
@ -688,6 +822,153 @@ class VlmCausalLM(FlashCausalLM):
def batch_type(self) -> Type[VlmCausalLMBatch]: def batch_type(self) -> Type[VlmCausalLMBatch]:
return self.batch_class return self.batch_class
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
input_lengths = [max_s] * bs
cache_lengths = [0] * bs
if max_bs is None:
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
input_embeds = torch.zeros(
(bs, self.model.config.text_config.hidden_size),
device=self.device,
dtype=self.dtype,
)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
config = getattr(self.model, "config", None)
rope_scaling = getattr(config, "rope_scaling", None) if config else None
if ( # mrope have position_ids per section, if so repeat n times
isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope"
):
n_sections = len(self.model.config.rope_scaling["mrope_section"])
position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
)
cache_lengths_tensor = torch.zeros(
bs, dtype=torch.int32, device=self.device
)
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).repeat(bs)
block_tables = block_tables.reshape((bs, max_bt))
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=input_lengths,
cache_lengths=cache_lengths,
input_lengths_tensor=input_lengths_tensor,
cache_lengths_tensor=cache_lengths_tensor,
max_current_length=max_s,
)
else:
if bs > max_bs:
raise RuntimeError(
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
)
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
input_embeds = self.cuda_graphs[max_bs]["input_embeds"][:bs]
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
if ATTENTION == "flashinfer":
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
else:
block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
slots = self.cuda_graphs[max_bs]["slots"][:bs]
input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
create_decode_state_cuda_graphs,
)
block_tables_ptr = torch.zeros(
bs + 1, dtype=torch.int32, device=self.device
)
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
state = create_decode_state_cuda_graphs(
device=input_ids.device,
block_tables=block_tables,
block_tables_ptr=block_tables_ptr,
last_page_len=last_page_len,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
else:
state = None
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"input_embeds": input_embeds,
"position_ids": position_ids,
"kv_cache": self.kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"cache_lengths": cache_lengths_tensor,
"state": state,
"graph": graph,
}
torch.cuda.synchronize()
# Run once outside to warmup
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=None,
input_lengths_tensor=input_lengths_tensor,
state=state,
cache_lengths_tensor=cache_lengths_tensor,
):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
self.model.forward(
input_ids=input_ids,
inputs_embeds=input_embeds,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
del seqlen
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
inputs_embeds=input_embeds,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize()
def get_vision_embeds( def get_vision_embeds(
self, self,
pixel_values: torch.Tensor, pixel_values: torch.Tensor,
@ -901,6 +1182,7 @@ class VlmCausalLM(FlashCausalLM):
# Copy inputs to the static inputs of the cuda graph # Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded # Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["input_embeds"][: inputs_embeds.shape[0]] = inputs_embeds
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
if ATTENTION == "flashinfer": if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(