Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-06-08 20:03:28 -07:00
parent 839477670a
commit b09d4cc142
11 changed files with 903 additions and 521 deletions

View File

@@ -83,9 +83,6 @@ try:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.models.pali_gemma import (
PaliGemmaBatch,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
@@ -153,7 +150,6 @@ if FLASH_ATTENTION:
)
VLM_BATCH_TYPES = {
PaliGemmaBatch,
FlashVlmCausalLMBatch,
FlashMllamaCausalLMBatch,
}
@@ -635,6 +631,7 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == BAICHUAN:
return FlashCausalLM(
@@ -784,6 +781,8 @@ def get_model(
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# TODO: Fix bug in rust image_text_replacement implementation
support_chunking=False,
)
elif model_type == QWEN2_5_VL:
return FlashVlmCausalLM(
@@ -799,6 +798,8 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
# TODO: Fix bug in rust image_text_replacement implementation
support_chunking=False,
)
elif model_type == QWEN3:
return FlashCausalLM(
@@ -824,6 +825,7 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == IDEFICS2:
return FlashVlmCausalLM(
@@ -868,7 +870,6 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
)
elif model_type == LLAVA_NEXT:
return FlashVlmCausalLM(

View File

@@ -163,25 +163,13 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
)
return inputs_embeds
def forward(
def get_vision_embeds(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
@@ -216,8 +204,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
self.config.vision_config.image_size // self.config.vision_config.patch_size
)
new_image_features = []
@@ -254,9 +241,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
@@ -264,10 +249,38 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
return image_features.view(-1, image_features.shape[-1])
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,

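The refactor above splits the old all-in-one `forward` into three steps: encode images, merge them into the token embeddings, then run the language model. A minimal sketch of how a caller might drive the new API (hypothetical `model`, `input_ids`, `pixel_values` and attention kwargs; not code from this commit):

```python
def run_llava_prefill(model, input_ids, pixel_values, image_sizes, position_ids, **attn_kwargs):
    # 1. Encode the images once; the result is flattened to (num_image_tokens, hidden_size).
    vision_embeds = model.get_vision_embeds(
        pixel_values=pixel_values, image_sizes=image_sizes
    )
    # 2. Embed the text tokens and scatter the image embeddings over the
    #    positions holding the image placeholder token.
    inputs_embeds = model.get_inputs_embeds(
        input_ids=input_ids, vision_embeds=vision_embeds
    )
    # 3. The language model now only ever sees embeddings, never pixel values.
    return model.forward(
        inputs_embeds=inputs_embeds, position_ids=position_ids, **attn_kwargs
    )
```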
View File

@@ -62,10 +62,40 @@ class PaliGemmaForConditionalGeneration(nn.Module):
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
self.dtype = weights.dtype
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
image_features = image_features.view(-1, image_features.shape[-1])
return image_features
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
mask = input_ids == self.config.image_token_index
inputs_embeds[mask] = vision_embeds
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -73,32 +103,13 @@ class PaliGemmaForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused here
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.
if cu_seqlen_prefill is not None:
position_ids += 1
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
# mask where image or padding tokens
mask = input_ids == self.config.image_token_index
# insert image features into input embeddings
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

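For PaliGemma the merge is a plain masked scatter: every position whose token id equals `config.image_token_index` is overwritten, in order, by one row of the projected vision features. A toy illustration with made-up ids and a 3-dim hidden size (not code from this commit):

```python
import torch

image_token_index = 99                               # made-up placeholder id
input_ids = torch.tensor([99, 99, 1, 2, 3])          # two image slots, three text tokens
inputs_embeds = torch.zeros(5, 3)                    # stand-in text embeddings
vision_embeds = torch.arange(6, dtype=torch.float32).view(2, 3)  # one row per image slot

mask = input_ids == image_token_index
inputs_embeds[mask] = vision_embeds                  # rows 0 and 1 now carry image features
```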
View File

@@ -734,33 +734,20 @@ class Idefics2ForConditionalGeneration(nn.Module):
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
def get_vision_embeds(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
assert pixel_values is not None
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
@@ -813,9 +800,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.eq(
patches_subgrid, (patch_size * patch_size)
)
patch_attention_mask = torch.gt(patches_subgrid, 0)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
@@ -830,12 +815,36 @@ class Idefics2ForConditionalGeneration(nn.Module):
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

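Note that the patch-mask change above (`torch.eq(patches_subgrid, patch_size * patch_size)` to `torch.gt(patches_subgrid, 0)`) is behavioral: the old check kept only patches whose pixels are all valid, the new one keeps any patch with at least one valid pixel. A toy comparison with made-up per-patch counts (not code from this commit):

```python
import torch

patch_size = 2
patches_subgrid = torch.tensor([4, 3, 0, 4])   # valid pixels summed per patch

full_only = torch.eq(patches_subgrid, patch_size * patch_size)  # [True, False, False, True]
any_valid = torch.gt(patches_subgrid, 0)                        # [True, True, False, True]
```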
View File

@@ -477,36 +477,19 @@ class Idefics3ForConditionalGeneration(nn.Module):
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
def get_vision_embeds(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
@@ -538,6 +521,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
].contiguous()
patch_size = self.config.vision_config.patch_size
"""
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
@@ -558,9 +542,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.eq(
patches_subgrid, (patch_size * patch_size)
)
patch_attention_mask = torch.gt(patches_subgrid, 0)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
@@ -576,10 +558,37 @@ class Idefics3ForConditionalGeneration(nn.Module):
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_indices=None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@@ -900,9 +900,33 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
)
return position_ids
def forward(
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if vision_embeds is not None:
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = vision_embeds
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -910,26 +934,10 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
# Unused in this model
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,

View File

@@ -474,9 +474,33 @@ class Qwen2VLForConditionalGeneration(nn.Module):
)
return position_ids
def forward(
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if vision_embeds is not None:
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = vision_embeds
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -484,26 +508,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@@ -1457,7 +1457,7 @@ class FlashCausalLM(Model):
if head_size is None:
# Some models use GQA and different sizes for o_proj
# and q_proj, that allows for that.
if hasattr(config, "head_dim"):
if getattr(config, "head_dim", None) is not None:
self.head_size = config.head_dim
else:
self.head_size = config.hidden_size // config.num_attention_heads
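The `hasattr` to `getattr(..., None) is not None` change matters for configs that declare `head_dim` but leave it as `None`; the old check would have accepted `None` as a head size. A minimal illustration with a stand-in config object (not code from this commit):

```python
class DummyConfig:
    head_dim = None
    hidden_size = 4096
    num_attention_heads = 32

config = DummyConfig()
assert hasattr(config, "head_dim")                 # old check passes, head_dim unusable
if getattr(config, "head_dim", None) is not None:  # new check falls through to the division
    head_size = config.head_dim
else:
    head_size = config.hidden_size // config.num_attention_heads
assert head_size == 128
```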
@@ -2263,6 +2263,8 @@ class FlashCausalLM(Model):
batch.prepare_for_decode(
self.dtype, self.use_contiguous_pa, self.bucketing_ctx
)
if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
self.set_inputs_embeds(batch)
prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta

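The new `set_inputs_embeds` call is duck-typed: text-only models define no such method and skip it, while the VLM subclasses below use it to encode pending images and stash the merged embeddings on the batch before the forward pass. A minimal sketch of the dispatch pattern (hypothetical classes, not code from this commit):

```python
class TextOnlyModel:
    pass                                   # no set_inputs_embeds -> hook is skipped

class VisionModel:
    def set_inputs_embeds(self, batch):
        batch.inputs_embeds = object()     # stand-in for "encode images and merge"

def maybe_set_inputs_embeds(model, batch):
    if hasattr(model, "set_inputs_embeds") and callable(model.set_inputs_embeds):
        model.set_inputs_embeds(batch)
```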
View File

@@ -1,7 +1,7 @@
import torch
from PIL import Image
from io import BytesIO
from dataclasses import dataclass
from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict
@@ -119,17 +119,17 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size
def image_text_replacement(processor, image_input, config, image_id: int) -> str:
def image_text_replacement(processor, image_input, config) -> str:
if config.model_type == "idefics2":
image_seq_len = 64
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
if processor.image_processor.do_image_splitting:
image_str *= 5
return image_str
return image_str, IDEFICS2_FAKE_TOKEN
if config.model_type == "idefics3":
# TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id]
n_rows = image_input["rows"][0][0]
n_cols = image_input["cols"][0][image_id]
n_cols = image_input["cols"][0][0]
image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
/ (config.scale_factor**2)
@@ -142,41 +142,41 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
image_token=IDEFICS3_IMAGE_TOKEN,
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
)
return image_str
return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
height, width = image_input["image_sizes"][0]
num_features = get_number_of_features(height, width, config)
log_master(
logger.info,
f"Found {num_features} features in image of resolution {height}x{width}",
)
return "<image>" * num_features
return "<image>" * num_features, "<image>"
elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens
return "<image>" * config.text_config.num_image_tokens, "<image>"
elif config.model_type == "qwen2_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
elif config.model_type == "qwen2_5_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
elif config.model_type == "gemma3":
# TODO: get correct number of features via reviewing the Gemma3 architecture
# and calculating the number of image tokens
num_pads = 256
padding = "<image_soft_token>" * num_pads
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n", "<start_of_image>"
elif config.model_type == "llama4":
patch_size = config.vision_config.patch_size
pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
aspect_ratios = image_input["aspect_ratios"][image_id]
aspect_ratios = image_input["aspect_ratios"][0]
image_height, image_width = image_input["pixel_values"][image_id].shape[-2:]
image_height, image_width = image_input["pixel_values"][0].shape[-2:]
num_patches_per_chunk = int(
(image_height // patch_size)
@@ -187,7 +187,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
aspect_ratios, num_patches_per_chunk
)
return tokens_for_this_image
return tokens_for_this_image, "<|image_start|>"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
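As a concrete check of the Qwen2-VL branches above (the `// 4` reflects the usual 2x2 spatial merge), an image grid of `(1, 36, 54)` yields `1 * 36 * 54 // 4 = 486` pad tokens between the vision markers. A small sketch mirroring that branch (not code from this commit):

```python
def qwen2_vl_replacement(grid_t: int, grid_h: int, grid_w: int):
    num_pads = grid_t * grid_h * grid_w // 4
    padding = "<|image_pad|>" * num_pads
    return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"

text, start_token = qwen2_vl_replacement(1, 36, 54)
assert text.count("<|image_pad|>") == 486
```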
@@ -200,6 +200,27 @@ def image_text_replacement_fixup(config, text: str) -> str:
return text
def preprocess_text(config, text: str) -> str:
if config.model_type == "paligemma":
return "<bos>" + text + "\n"
return text
def preprocess_image(config, img):
model_type = config.model_type
if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
img = img.resize((img.width * 2, img.height * 2))
if model_type == "paligemma":
img = img.convert("RGB")
if model_type not in {"llava_next", "gemma3", "llama4"}:
# TODO: check if this is needed
img = [img]
return img
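A quick usage sketch of the two helpers above (a dummy config class stands in for the HF config; the 20-pixel threshold exists because the default warmup image is 20x20):

```python
from PIL import Image

class Cfg:                       # stand-in for the model config
    model_type = "qwen2_vl"

img = Image.new("RGB", (20, 20))
out = preprocess_image(Cfg(), img)
assert out[0].size == (40, 40)   # upscaled 2x and wrapped in a list

assert preprocess_text(Cfg(), "hi") == "hi"
Cfg.model_type = "paligemma"
assert preprocess_text(Cfg(), "hi") == "<bos>hi\n"
```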
def get_unpadded_features(
original_height: int,
original_width: int,
@@ -254,66 +275,115 @@ def get_number_of_features(height: int, width: int, config) -> int:
return unpadded_features + newline_features + base_features
def scatter_image_embeds(
embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> torch.Tensor:
if is_embed is None:
return embeds
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed] = embeds
return placeholders
def gather_image_embeds(
embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
if is_embed is None:
return embeds
sel = embeds[is_embed]
return sel if sel.numel() else None
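The pair above lets a partially prefilled image span be sliced by token position: `scatter_image_embeds` pads the embeddings out to the full placeholder span (NaN rows where the span contains non-embed tokens such as separators), and `gather_image_embeds` recovers only the real rows for the slice being processed. A toy round trip using the two helpers:

```python
import torch

embeds = torch.tensor([[1.0, 1.0], [2.0, 2.0]])    # two real image embeddings
is_embed = torch.tensor([True, False, True])        # middle slot is a separator token

padded = scatter_image_embeds(embeds, is_embed)     # shape (3, 2), NaN in row 1
recovered = gather_image_embeds(padded, is_embed)   # back to the two real rows
assert torch.equal(recovered, embeds)
```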
@dataclass
class ImagePositions:
offset: int
length: int
id: int
num_placeholder_tokens: int
is_embed: Optional[torch.Tensor] = None
class FlashVlmCausalLMBatch(FlashCausalLMBatch):
image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]
image_positions: Optional[List[List[ImagePositions]]]
encoder_cache: Optional[List[Dict[int, torch.Tensor]]]
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
image_grid_thw: Optional[torch.Tensor]
cache_entries_to_free: List[Tuple[int, int]]
has_image_inputs: bool = False
inputs_embeds: Optional[torch.Tensor] = None
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches, padded_total_bs: int = 0):
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
batch.image_inputs = []
batch.image_positions = []
batch.encoder_cache = []
for b in batches:
if b.image_inputs is not None:
batch.image_inputs.extend(b.image_inputs)
else:
batch.image_inputs.append(None)
if b.image_positions is not None:
batch.image_positions.extend(b.image_positions)
else:
batch.image_positions.append(None)
if b.encoder_cache is not None:
batch.encoder_cache.extend(b.encoder_cache)
else:
batch.encoder_cache.append(None)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
batch.inputs_embeds = None
# To be filled in prepare_for_prefill
batch.has_image_inputs = False
batch.cache_entries_to_free = []
return batch return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
image_inputs = []
image_positions = []
encoder_cache = []
for request_id in request_ids:
idx = self.requests_idx_mapping[request_id]
image_inputs.append(self.image_inputs[idx])
image_positions.append(self.image_positions[idx])
encoder_cache.append(self.encoder_cache[idx])
batch = super().filter(request_ids)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
batch.inputs_embeds = None
batch.image_inputs = image_inputs
batch.image_positions = image_positions
batch.encoder_cache = encoder_cache
# To be filled in prepare_for_prefill
batch.has_image_inputs = False
batch.cache_entries_to_free = []
return batch return batch
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
# Process images first. We need all of them so that the processor
# can make the image splits the same size. And we need the final
# sizes to insert correct number of image tokens.
images = []
for r in requests:
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
pass
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
# default warmup image is 20x20
if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if image.width <= 20:
w = image.width * 2
h = image.height * 2
image = image.resize((w, h))
if config.model_type == "llava_next":
images.append(image)
elif config.model_type == "gemma3":
images.append(image)
elif config.model_type == "llama4":
images.append(image)
else:
images.append([image])
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
if images:
kwargs = {}
if (
hasattr(processor, "image_processor_class")
@@ -321,38 +391,143 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
):
kwargs["return_row_col_info"] = True
image_inputs = processor.image_processor(
images, return_tensors="pt", **kwargs
)
else:
image_inputs = None
batch_tokenized_inputs = []
max_length = 0
image_id = 0
vocab = tokenizer.get_vocab()
if not hasattr(config, "image_token_index"):
config.image_token_index = config.image_token_id
batch_tokenized_inputs: List[List[int]] = []
batch_image_inputs: List[Optional[List[dict]]] = []
batch_image_positions: List[Optional[List[ImagePositions]]] = []
for r in requests:
full_text = ""
text_parts = []
image_inputs = []
image_texts = []
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += chunk.text
text = preprocess_text(config, chunk.text)
text_parts.append(text)
elif chunk_type == "image":
full_text += image_text_replacement(
processor, image_inputs, config, image_id
)
image_id += 1
full_text = image_text_replacement_fixup(config, full_text)
img = Image.open(BytesIO(chunk.image.data))
img = preprocess_image(config, img)
image_input = processor.image_processor(
[img], return_tensors="pt", **kwargs
)
image_inputs.append(image_input)
img_text, img_start_token_str = image_text_replacement(
processor, image_input, config
)
text_parts.append(img_text)
image_texts.append([image_id, img_start_token_str, img_text])
image_id += 1
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
full_text = image_text_replacement_fixup(config, "".join(text_parts))
input_ids = tokenizer(
full_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
add_special_tokens=(
r.add_special_tokens if config.model_type != "paligemma" else False
),
)["input_ids"]
max_length = max(max_length, len(input_ids))
batch_tokenized_inputs.append(input_ids)
return batch_tokenized_inputs, image_inputs
if len(image_inputs) > 0:
img_start_token = vocab[image_texts[0][1]]
image_positions = cls.get_image_positions(
input_ids, image_texts, img_start_token, config, tokenizer
)
else:
image_inputs = None
image_positions = None
batch_tokenized_inputs.append(input_ids)
batch_image_inputs.append(image_inputs)
batch_image_positions.append(image_positions)
return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
@classmethod
def get_image_positions(
cls,
input_ids: List[int],
image_texts: List[Tuple[int, str, str]],
img_start_token: int,
config,
tokenizer: PreTrainedTokenizerBase,
) -> List[ImagePositions]:
image_positions = []
num_images = len(image_texts)
input_ids_t = torch.as_tensor(input_ids)
img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
num_tokens = input_ids_t.numel()
last_pos = 0
for i in range(num_images):
image_id, img_start_token_str, img_text = image_texts[i]
img_text = image_text_replacement_fixup(config, img_text)
if config.model_type == "gemma3":
img_text = img_text.replace("\n\n", "")
tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[
"input_ids"
][0]
length = tokens.numel()
assert (
length <= num_tokens
), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens"
pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
index = img_start_token_pos[pos]
assert torch.equal(
input_ids_t[index : index + length], tokens
), "Image tokens not found in input_ids"
is_embed = tokens == config.image_token_index
num_placeholder_tokens = int(is_embed.sum())
if num_placeholder_tokens == length:
is_embed = None
pos = ImagePositions(
offset=index,
length=length,
id=image_id,
num_placeholder_tokens=num_placeholder_tokens,
is_embed=is_embed,
)
image_positions.append(pos)
last_pos = index + length
if (
config.model_type == "idefics2"
and i + 1 != num_images
and input_ids[last_pos] == config.image_token_index
):
fake_token = last_pos - 1
fake_token_index = torch.searchsorted(
img_start_token_pos, fake_token, right=False
)
img_start_token_pos[fake_token_index] = last_pos
image_texts[i + 1][2] = image_texts[i + 1][2][
len(img_start_token_str) :
]
return image_positions
@classmethod
def from_pb_processor(
@@ -364,33 +539,162 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashVlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
batch_tokenized_inputs, image_inputs, image_positions = (
cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
if "image_grid_thw" in image_inputs:
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else:
batch.image_grid_thw = None
else:
batch.image_inputs = image_inputs
batch.image_positions = image_positions
batch.encoder_cache = [{} for _ in range(len(pb.requests))]
if len(image_inputs):
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
def prepare_for_prefill(self):
super().prepare_for_prefill()
self.has_image_inputs = False
self.cache_entries_to_free = []
self.pixel_values = []
assert (
len(self.cache_lengths)
== len(self.input_lengths)
== len(self.prefilling_mask)
), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
for i, (
cache_length,
input_length,
request_prefilling,
) in enumerate(
zip(
self.cache_lengths,
self.input_lengths,
self.prefilling_mask,
)
):
if not request_prefilling or self.image_positions[i] is None:
continue
for image_position in self.image_positions[i]:
if image_position is None:
continue
start_pos = image_position.offset
length = image_position.length
if start_pos >= cache_length + input_length:
# No encoder input required at this step
break
if start_pos + length <= cache_length:
# The encoder input is already processed
continue
self.has_image_inputs = True
if image_position.id not in self.encoder_cache[i]:
image_inputs = self.image_inputs[i][image_position.id]
self.pixel_values.append((i, image_position.id, image_inputs))
# Remove the image from the image_inputs
self.image_inputs[i][image_position.id] = None
if not self.has_image_inputs:
self.pixel_values = None
self.pixel_attention_mask = None
self.image_sizes = None
self.image_grid_thw = None
else:
image_grid_thw_list = [
x[2]["image_grid_thw"]
for x in self.pixel_values
if "image_grid_thw" in x[2]
]
if image_grid_thw_list:
self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to(
self.input_ids.device
)
else:
self.image_grid_thw = None
def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
encoder_outputs, img_pos.is_embed
)
def gather_vision_embeds(self):
device = self.input_ids.device
chunks = []
for (
i,
cache_length,
input_length,
request_prefilling,
) in zip(
range(len(self.requests)),
self.cache_lengths,
self.input_lengths,
self.prefilling_mask,
):
if not request_prefilling or self.image_positions[i] is None:
continue
for image_position in self.image_positions[i]:
if image_position is None:
continue
start_pos = image_position.offset
length = image_position.length
if start_pos >= cache_length + input_length:
# No encoder input required at this step
break
if start_pos + length <= cache_length:
# The encoder input is already processed
continue
start_idx = max(cache_length - start_pos, 0)
end_idx = min(cache_length - start_pos + input_length, length)
assert (
image_position.id in self.encoder_cache[i]
), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
encoder_output = self.encoder_cache[i][image_position.id]
is_embed = image_position.is_embed
if is_embed is not None:
is_embed = is_embed[start_idx:end_idx]
embeds = gather_image_embeds(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
if embeds is not None:
chunks.append(embeds)
if end_idx == length:
self.cache_entries_to_free.append((i, image_position.id))
self.image_positions[i][image_position.id] = None
if len(chunks) == 0:
return None
return torch.cat(chunks, dim=0).to(device)
def free_encoder_cache(self):
for i, image_id in self.cache_entries_to_free:
self.encoder_cache[i].pop(image_id, None)
self.cache_entries_to_free = []
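The slice arithmetic in `gather_vision_embeds` can be checked with concrete numbers (hypothetical chunked prefill): an image span starting at prompt position 10 with length 8, while the current chunk covers tokens 12..17 (`cache_length=12`, `input_length=6`), contributes rows 2..7 of its cached encoder output, and because the span ends inside this chunk its cache entry is scheduled to be freed:

```python
start_pos, length = 10, 8            # image placeholder span inside the prompt
cache_length, input_length = 12, 6   # tokens already cached / tokens in this chunk

start_idx = max(cache_length - start_pos, 0)                     # 2
end_idx = min(cache_length - start_pos + input_length, length)   # 8
assert (start_idx, end_idx) == (2, 8)                            # rows 2..7 are gathered
assert end_idx == length                                         # entry can be freed afterwards
```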
class FlashVlmCausalLM(FlashCausalLM):
def __init__(
@@ -402,6 +706,7 @@ class FlashVlmCausalLM(FlashCausalLM):
batch_class=FlashVlmCausalLMBatch,
revision,
trust_remote_code: bool,
support_chunking: bool = False,
**kwargs,
):
if PREFIX_CACHING:
@@ -419,8 +724,7 @@ class FlashVlmCausalLM(FlashCausalLM):
model_id=model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# FIXME: VLM do not work with context chunking yet
support_chunking=False,
support_chunking=support_chunking,
**kwargs,
)
@@ -471,9 +775,12 @@ class FlashVlmCausalLM(FlashCausalLM):
bucketing_ctx=None,
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
inputs_embeds = self.get_inputs_embeds(
input_ids=input_ids.to(self.device),
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
inputs_embeds=inputs_embeds,
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
@@ -481,10 +788,7 @@ class FlashVlmCausalLM(FlashCausalLM):
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=hpu_attention_meta,
lm_head_indices=None,
pixel_values=None,
pixel_attention_mask=None,
image_sizes=None,
image_grid_thw=None,
attention_mask=None,
)
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
@@ -546,6 +850,84 @@ class FlashVlmCausalLM(FlashCausalLM):
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def get_vision_embeds(
self,
pixel_values: torch.Tensor,
pixel_attention_mask: torch.Tensor,
image_sizes: torch.Tensor,
image_grid_thw: torch.Tensor,
):
embeds = self.model.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_sizes=image_sizes,
image_grid_thw=image_grid_thw,
)
return embeds
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: Optional[torch.Tensor] = None,
):
return self.model.get_inputs_embeds(
input_ids=input_ids,
vision_embeds=vision_embeds,
)
def encode_images(self, batch):
if batch.pixel_values is not None:
device = batch.input_ids.device
for request_id, image_id, image_input in batch.pixel_values:
pixel_values = image_input["pixel_values"].to(device)
if "pixel_attention_mask" in image_input:
pixel_attention_mask = image_input["pixel_attention_mask"].to(
device
)
else:
pixel_attention_mask = None
if "image_sizes" in image_input:
image_sizes = image_input["image_sizes"].to(device)
else:
image_sizes = None
if "image_grid_thw" in image_input:
image_grid_thw = image_input["image_grid_thw"].to(device)
else:
image_grid_thw = None
encoder_outputs = self.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_sizes=image_sizes,
image_grid_thw=image_grid_thw,
)
batch.update_encoder_cache(
encoder_outputs,
request_id,
batch.image_positions[request_id][image_id],
)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
def set_inputs_embeds(self, batch):
if batch.has_image_inputs:
self.encode_images(batch)
vision_embeds = batch.gather_vision_embeds()
batch.has_image_inputs = False
else:
vision_embeds = None
inputs_embeds = self.get_inputs_embeds(
batch.input_ids, vision_embeds=vision_embeds
)
batch.inputs_embeds = inputs_embeds
def forward(
self,
batch: FlashVlmCausalLMBatch,
@@ -593,6 +975,7 @@ class FlashVlmCausalLM(FlashCausalLM):
position_ids = new_position_ids
else:
input_ids = batch.input_ids
inputs_embeds = batch.inputs_embeds
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
@@ -609,6 +992,18 @@ class FlashVlmCausalLM(FlashCausalLM):
)
batch.position_ids = position_ids
attention_mask = None
attention_mask_forward = None
if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
attention_mask = self.model.get_attention_mask(
input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
)
min_dtype = torch.finfo(self.dtype).min
attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
input_ids.device
)
attention_mask = attention_mask.reshape(-1)
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
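The Gemma3 branch above converts a boolean attention mask into the additive form the attention kernels expect: allowed positions become 0 and masked positions become the most negative representable value, so their weight vanishes after softmax. A minimal illustration (not code from this commit):

```python
import torch

dtype = torch.bfloat16
bool_mask = torch.tensor([[True, False], [True, True]])
min_dtype = torch.finfo(dtype).min
additive = torch.where(bool_mask, 0, min_dtype)
# additive[0, 1] is about -3.4e38; added to a logit it zeroes that position after softmax
```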
@@ -639,7 +1034,7 @@ class FlashVlmCausalLM(FlashCausalLM):
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache,
@@ -647,18 +1042,11 @@ class FlashVlmCausalLM(FlashCausalLM):
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
attention_mask=attention_mask_forward,
**kwargs,
)
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
batch.image_grid_thw = None
batch.free_encoder_cache()
return logits, speculative_logits

View File

@@ -49,7 +49,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches, padded_total_bs: int = 0):
batch = super().concatenate(batches, padded_total_bs)
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
batch.pixel_values = None
batch.pixel_attention_mask = None
@@ -228,6 +228,10 @@ def generate_cross_attention_states(
class FlashMllamaCausalLM(FlashVlmCausalLM):
def set_inputs_embeds(self, batch):
# Set the input embeddings to None, as we are using the input_ids for the model
batch.inputs_embeds = None
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
):

View File

@@ -1,71 +0,0 @@
from io import BytesIO
from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
image_text_replacement,
)
from text_generation_server.pb.generate_pb2 import Request
tracer = trace.get_tracer(__name__)
class PaliGemmaBatch(FlashVlmCausalLMBatch):
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
full_text = ""
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += "<bos>" + chunk.text + "\n"
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# TODO do_convert_RGB should be on by default ?
image = image.convert("RGB")
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(
processor, image_input, config, image_id
)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=False,
)["input_ids"]
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs