diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 93a2f8bf..0a211ec3 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -128,9 +128,6 @@ try: from text_generation_server.models.custom_modeling.flash_neox_modeling import ( FlashGPTNeoXForCausalLM, ) - from text_generation_server.models.pali_gemma import ( - PaliGemmaBatch, - ) from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( PaliGemmaForConditionalGeneration, ) @@ -1196,6 +1193,7 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + support_chunking=False, ) elif FLASH_TRANSFORMERS_BACKEND: from transformers import Gemma3ForConditionalGeneration as Gemma3Model @@ -1208,6 +1206,7 @@ def get_model( speculator=speculator, dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + support_chunking=False, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3")) @@ -1523,6 +1522,8 @@ def get_model( kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + # TODO: Fix bug in rust image_text_replacement implementation + support_chunking=False, ) # TODO: Uncomment when transformers is refactored # elif FLASH_TRANSFORMERS_BACKEND: @@ -1554,6 +1555,8 @@ def get_model( lora_adapter_ids=lora_adapter_ids, config_class=Qwen2_5_VLConfig, processor_class=Qwen2_5_VLProcessor, + # TODO: Fix bug in rust image_text_replacement implementation + support_chunking=False, ) # TODO: Uncomment when transformers is refactored # elif FLASH_TRANSFORMERS_BACKEND: @@ -1583,6 +1586,7 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + support_chunking=False, ) # TODO: Uncomment when transformers is refactored and cross attn is added # elif FLASH_TRANSFORMERS_BACKEND: @@ -1676,7 +1680,6 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, - batch_class=PaliGemmaBatch, ) elif FLASH_TRANSFORMERS_BACKEND: from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel @@ -1689,7 +1692,6 @@ def get_model( speculator=speculator, dtype=torch.bfloat16, trust_remote_code=trust_remote_code, - batch_class=PaliGemmaBatch, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma")) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py index 58afd643..b0047f1e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py @@ -700,6 +700,7 @@ class Gemma3ForConditionalGeneration(nn.Module): self.pad_token_id = ( config.pad_token_id if config.pad_token_id is not None else -1 ) + self.dtype = weights.dtype def get_attention_mask( self, @@ -762,9 +763,42 @@ class Gemma3ForConditionalGeneration(nn.Module): else: return torch.where(full_attention_mask, 0, min_dtype).to(device) - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + pixel_values = pixel_values.to(dtype=self.dtype) + 
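+        # The vision tower, post-layernorm and multimodal projector run below once
+        # per image batch; the output is flattened to (num_image_tokens, hidden_size)
+        # so that get_inputs_embeds can scatter it into the text embeddings.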
image_outputs = self.vision_model(pixel_values) + vision_outputs = self.post_vision_model_layernorm( + image_outputs.last_hidden_state + ) + image_features = self.multimodal_projector(vision_outputs) + image_features = image_features.view(-1, image_features.shape[-1]) + return image_features + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # Replace the image token embeddings with the vision features + image_token_mask = (input_ids == self.config.image_token_index).to( + input_ids.device + ) + inputs_embeds[image_token_mask] = vision_embeds.view( + -1, vision_embeds.shape[-1] + ) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -777,35 +811,12 @@ class Gemma3ForConditionalGeneration(nn.Module): pixel_values: torch.FloatTensor = None, # Unused here attention_mask: Optional[torch.BoolTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - inputs_embeds = self.text_model.embed_tokens(input_ids) if cu_seqlen_prefill is not None: max_s += 1 position_ids += 1 - if pixel_values is not None: - pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) - image_outputs = self.vision_model(pixel_values) - vision_outputs = self.post_vision_model_layernorm( - image_outputs.last_hidden_state - ) - image_features = self.multimodal_projector(vision_outputs) - - image_token_mask = (input_ids == self.config.image_token_index).to( - input_ids.device - ) - inputs_embeds[image_token_mask] = image_features.view( - -1, image_features.shape[-1] - ) - attention_mask = self.get_attention_mask( - input_ids, - cu_seqlen_prefill, - inputs_embeds.dtype, - ) # Use flash attention for text-only input # else: # if cu_seqlen_prefill is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 7ad294f4..c855745c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -116,11 +116,10 @@ class MistralAttention(torch.nn.Module): ) self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size - if hasattr(config, "head_dim"): + if getattr(config, "head_dim", None) is not None: self.head_size = config.head_dim else: self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = PositionRotaryEmbedding.static( config=config, dim=self.head_size, diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index b1f89eff..41af2b9e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -62,10 +62,40 @@ class PaliGemmaForConditionalGeneration(nn.Module): self.pad_token_id = ( config.pad_token_id if config.pad_token_id is not None else -1 ) + self.dtype = weights.dtype + + def get_vision_embeds( + self, + 
pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + pixel_values = pixel_values.to(dtype=self.dtype) + image_outputs = self.vision_tower(pixel_values) + last_hidden_state = self.post_vision_tower_layernorm( + image_outputs.last_hidden_state + ) + image_features = self.multi_modal_projector(last_hidden_state) + image_features = image_features.view(-1, image_features.shape[-1]) + return image_features + + def get_inputs_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + mask = input_ids == self.config.image_token_index + inputs_embeds[mask] = vision_embeds + + return inputs_embeds def forward( self, - input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -75,33 +105,15 @@ class PaliGemmaForConditionalGeneration(nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, # Unused here - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. if cu_seqlen_prefill is not None: max_s += 1 position_ids += 1 - if pixel_values is not None: - pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) - image_outputs = self.vision_tower(pixel_values) - last_hidden_state = self.post_vision_tower_layernorm( - image_outputs.last_hidden_state - ) - image_features = self.multi_modal_projector(last_hidden_state) - - # mask where image or padding tokens - mask = input_ids == self.config.image_token_index - - # insert image features into input embeddings - inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 923123d6..5c0d2fcc 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -733,9 +733,93 @@ class Idefics2ForConditionalGeneration(nn.Module): inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) return inputs_embeds - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + assert pixel_values is not None + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, 
*pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), + ) + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + return image_hidden_states.view(-1, image_hidden_states.shape[-1]) + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, vision_embeds + ) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -745,82 +829,10 @@ class Idefics2ForConditionalGeneration(nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, # Unused here - image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - all_states = [] - all_pixel_values = pixel_values - all_pixel_mask = pixel_attention_mask - for i in range(batch_size): - pixel_values = all_pixel_values.to( - dtype=self.dtype - ) # fp16 compatibility - pixel_values = pixel_values[i : i + 1] - pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. 
- nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3) - ) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=( - pixel_values.size(0), - pixel_values.size(2), - pixel_values.size(3), - ), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask/pP p - pixel_attention_mask = all_pixel_mask[i : i + 1] - pixel_attention_mask = pixel_attention_mask.view( - 1 * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[ - real_images_inds - ].contiguous() - - patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold( - dimension=1, size=patch_size, step=patch_size - ) - patches_subgrid = patches_subgrid.unfold( - dimension=2, size=patch_size, step=patch_size - ) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - # Modality projection & resampling - image_hidden_states = self.connector( - image_hidden_states, - attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), - ) - all_states.append(image_hidden_states) - image_hidden_states = torch.stack(all_states, dim=0) - # When we generate, we don't want to replace the potential image_token_id that we generated by images - # that simply don't exist - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_hidden_states - ) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py index 580398cb..6d303c2c 100644 --- a/server/text_generation_server/models/custom_modeling/idefics3.py +++ b/server/text_generation_server/models/custom_modeling/idefics3.py @@ -476,9 +476,92 @@ class Idefics3ForConditionalGeneration(nn.Module): inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) return inputs_embeds - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. 
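+            # An all-zero image slot is padding: its zero count equals
+            # nb_values_per_image, so real_images_inds keeps only the slots that
+            # contain at least one non-zero pixel.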
+ nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + ) + + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + + return image_hidden_states.view(-1, image_hidden_states.shape[-1]) + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, vision_embeds + ) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -488,83 +571,11 @@ class Idefics3ForConditionalGeneration(nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, # Unused here - image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - all_states = [] - all_pixel_values = pixel_values - all_pixel_mask = pixel_attention_mask - for i in range(batch_size): - pixel_values = all_pixel_values.to( - dtype=self.dtype - ) # fp16 compatibility - pixel_values = pixel_values[i : i + 1] - pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. 
- nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3) - ) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=( - pixel_values.size(0), - pixel_values.size(2), - pixel_values.size(3), - ), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask/pP p - pixel_attention_mask = all_pixel_mask[i : i + 1] - pixel_attention_mask = pixel_attention_mask.view( - 1 * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[ - real_images_inds - ].contiguous() - - patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold( - dimension=1, size=patch_size, step=patch_size - ) - patches_subgrid = patches_subgrid.unfold( - dimension=2, size=patch_size, step=patch_size - ) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - # Modality projection & resampling - image_hidden_states = self.connector( - image_hidden_states, - ) - - all_states.append(image_hidden_states) - image_hidden_states = torch.stack(all_states, dim=0) - - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_hidden_states - ) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index df7366ea..56a9565b 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -163,9 +163,114 @@ class LlavaNextForConditionalGeneration(nn.Module): ) return inputs_embeds - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() + # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" + # 1. Extract the input embeddings + + # 2. Merge text and images + num_images, num_patches, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view( + num_images * num_patches, channels, height, width + ) + image_features = self.vision_tower(pixel_values) + + # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] + # Already done within the clip model + selected_image_feature = image_features.last_hidden_state + + if self.config.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise RuntimeError( + f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." 
+ ) + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [num_patches] * num_images + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = ( + self.config.vision_config.image_size // self.config.vision_config.patch_size + ) + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with the image size." + ) + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1 + ), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, self.image_newline[None]), dim=0 + ) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + return image_features.view(-1, image_features.shape[-1]) + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, vision_embeds + ) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -175,102 +280,10 @@ class LlavaNextForConditionalGeneration(nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, # Unused for this model - pixel_attention_mask=None, - image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None and len(pixel_values) > 0: - # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() - # assert num_special_image_tokens == 
len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" - # 1. Extract the input embeddings - - # 2. Merge text and images - num_images, num_patches, channels, height, width = pixel_values.shape - pixel_values = pixel_values.view( - num_images * num_patches, channels, height, width - ) - image_features = self.vision_tower(pixel_values) - - # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] - # Already done within the clip model - selected_image_feature = image_features.last_hidden_state - - if self.config.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.config.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - else: - raise RuntimeError( - f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." - ) - - image_features = self.multi_modal_projector(selected_image_feature) - - # split up image_features for each of the individual images - # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) - # if we assume each image has 5 image features (base image + 4 patches) - split_sizes = [num_patches] * num_images - image_features = torch.split(image_features, split_sizes, dim=0) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - height = width = ( - self.config.vision_config.image_size - // self.config.vision_config.patch_size - ) - - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - - if height * width != base_image_feature.shape[0]: - raise ValueError( - "The number of patches is not consistent with the image size." 
- ) - - # Dimensions are intentionally swapped to be bug-compatible with - # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_sizes[image_idx], - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, - ) - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 - ) - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat( - ( - image_feature, - self.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1 - ), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat( - (base_image_feature, image_feature), dim=0 - ) - else: - image_feature = image_feature[0] - image_feature = torch.cat( - (image_feature, self.image_newline[None]), dim=0 - ) - new_image_features.append(image_feature) - image_features = torch.stack(new_image_features, dim=0) - - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_features - ) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index 066de6a2..e2fc60b1 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -922,9 +922,32 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): ) return position_ids - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0) + return image_embeds + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if vision_embeds is not None: + inputs_embeds[input_ids == self.image_token_id] = vision_embeds + + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -934,26 +957,11 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor], - pixel_values: torch.FloatTensor = None, - image_grid_thw: Optional[torch.LongTensor] = None, # Unused in this model - video_grid_thw: Optional[torch.LongTensor] = None, - pixel_attention_mask=None, - image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, ): - inputs_embeds = self.embed_tokens(input_ids) - - # apply the visual model to the pixel values if they are provided - if pixel_values is not None and len(pixel_values) > 0: - if pixel_values is not None: - image_embeds = self.visual( - pixel_values, grid_thw=image_grid_thw - ).squeeze(0) - inputs_embeds[input_ids == self.image_token_id] 
= image_embeds - hidden_states = self.text_model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 26e6fede..75f718bd 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -500,9 +500,32 @@ class Qwen2VLForConditionalGeneration(nn.Module): ) return position_ids - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0) + return image_embeds + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if vision_embeds is not None: + inputs_embeds[input_ids == self.image_token_id] = vision_embeds + + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -512,25 +535,10 @@ class Qwen2VLForConditionalGeneration(nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor], - pixel_values: torch.FloatTensor = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - pixel_attention_mask=None, - image_sizes: Optional[torch.LongTensor] = None, adapter_data: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, + attention_mask=None, ): - inputs_embeds = self.embed_tokens(input_ids) - - # apply the visual model to the pixel values if they are provided - if pixel_values is not None and len(pixel_values) > 0: - if pixel_values is not None: - image_embeds = self.visual( - pixel_values, grid_thw=image_grid_thw - ).squeeze(0) - inputs_embeds[input_ids == self.image_token_id] = image_embeds - hidden_states = self.text_model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a28ef381..207226ff 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1298,7 +1298,7 @@ class FlashCausalLM(Model): if head_size is None: # Some models use GQA and different sizes for o_proj # and q_proj, that allows for that. 
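            # Note: hasattr() alone is not enough here, since a config may define
            # head_dim = None; checking the value instead falls back to
            # hidden_size // num_attention_heads in that case.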
- if hasattr(config, "head_dim"): + if getattr(config, "head_dim", None) is not None: self.head_size = config.head_dim else: self.head_size = config.hidden_size // config.num_attention_heads @@ -1896,6 +1896,9 @@ class FlashCausalLM(Model): if prefill: batch.prepare_for_prefill() + if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds): + self.set_inputs_embeds(batch) + prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py index c268ff9a..af9a811c 100644 --- a/server/text_generation_server/models/mllama_causal_lm.py +++ b/server/text_generation_server/models/mllama_causal_lm.py @@ -29,10 +29,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): aspect_ratio_mask: Optional[torch.Tensor] = None cross_attention_states: Optional[torch.Tensor] = None + def prepare_for_prefill(self): + super(VlmCausalLMBatch, self).prepare_for_prefill() + @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches): - batch = super().concatenate(batches) + batch = super(VlmCausalLMBatch, cls).concatenate(batches) batch.pixel_values = None batch.pixel_attention_mask = None @@ -196,6 +199,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): class MllamaCausalLM(VlmCausalLM): + def set_inputs_embeds(self, batch): + # Set the input embeddings to None, as we are using the input_ids for the model + batch.inputs_embeds = None + + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): + super(VlmCausalLM, self).cuda_graph_warmup(bs, max_s, max_bt) + def forward( self, batch: MllamaCausalLMBatch, diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py deleted file mode 100644 index fe75570e..00000000 --- a/server/text_generation_server/models/pali_gemma.py +++ /dev/null @@ -1,71 +0,0 @@ -from io import BytesIO -from PIL import Image -import torch -import torch.distributed -from opentelemetry import trace -from typing import Iterable -from text_generation_server.models.vlm_causal_lm import ( - VlmCausalLMBatch, - image_text_replacement, -) - -from text_generation_server.pb.generate_pb2 import Request - -tracer = trace.get_tracer(__name__) - - -class PaliGemmaBatch(VlmCausalLMBatch): - @classmethod - def batch_tokenized_inputs( - cls, requests: Iterable[Request], tokenizer, processor, config - ): - batch_inputs = [] - image_inputs = [] - max_truncation = 0 - for r in requests: - full_text = "" - image_id = 0 - for chunk in r.input_chunks.chunks: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - full_text += "" + chunk.text + "\n" - elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - # TODO do_convert_RGB should be on by default ? 
- image = image.convert("RGB") - image_input = processor.image_processor(image, return_tensors="pt") - full_text += image_text_replacement( - processor, image_input, config, image_id - ) - image_inputs.append(image_input) - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") - - batch_inputs.append(full_text) - max_truncation = max(max_truncation, r.truncate) - - batch_tokenized_inputs = tokenizer( - batch_inputs, - truncation=True, - max_length=max_truncation, - add_special_tokens=False, - )["input_ids"] - if image_inputs: - image_input = image_inputs[0] - new_image_inputs = { - "pixel_values": torch.cat( - [img["pixel_values"] for img in image_inputs], dim=0 - ), - } - if "pixel_attention_mask" in image_input: - new_image_inputs["pixel_attention_mask"] = torch.cat( - [img["pixel_attention_mask"] for img in image_inputs], dim=0 - ) - if "image_sizes" in image_input: - new_image_inputs["image_sizes"] = torch.cat( - [img["image_sizes"] for img in image_inputs], dim=0 - ) - image_inputs = new_image_inputs - else: - image_inputs = None - return batch_tokenized_inputs, image_inputs diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py index 280fa0bd..98644836 100644 --- a/server/text_generation_server/models/transformers_flash_vlm.py +++ b/server/text_generation_server/models/transformers_flash_vlm.py @@ -163,6 +163,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): processor_kwargs=None, kv_cache_dtype: Optional[torch.dtype] = None, batch_class=VlmCausalLMBatch, + support_chunking: bool = True, ): self.batch_class = batch_class self.quantize = quantize @@ -304,6 +305,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): device=device, rank=rank, world_size=world_size, + support_chunking=support_chunking, ) # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code @@ -338,6 +340,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): trust_remote_code: bool = False, batch_class: Optional[type] = VlmCausalLMBatch, processor_kwargs: Optional[dict] = None, + support_chunking: bool = True, ): return cls( model_id=model_id, @@ -349,6 +352,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): trust_remote_code=trust_remote_code, batch_class=batch_class, processor_kwargs=processor_kwargs, + support_chunking=support_chunking, ) def _model_forward( @@ -368,6 +372,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): image_grid_thw: Optional[torch.LongTensor] = None, pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, ): # A value of `None` (i.e. 
no logit slicing) translates to `0` in Transformers logits_to_keep = lm_head_indices if lm_head_indices is not None else 0 @@ -377,9 +382,12 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, ) + inputs["input_ids"] = None + # This is equivalent to `self.model.forward`, see the monkey patch in __init__ logits = self.model.original_forward( input_ids=inputs["input_ids"], + inputs_embeds=inputs_embeds.unsqueeze(0), position_ids=inputs["position_ids"], past_key_values=None, # we use self.kv_cache instead of transformers cache object use_cache=False, # we use self.kv_cache instead of transformers cache object @@ -568,3 +576,48 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM): inputs["cache_position"] = position_ids inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device) return inputs + + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + image_features = self.model.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=self.model.config.vision_config.vision_feature_layer, + vision_feature_select_strategy=self.model.config.vision_config.vision_feature_select_strategy, + image_sizes=image_sizes, + ) + + vision_flat = image_features.view(-1, image_features.size(-1)) + projected_vision_flat = self.model.multi_modal_projector(vision_flat) + return projected_vision_flat + + def get_inputs_embeds(self, input_ids, vision_embeds=None): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + + if vision_embeds is not None: + original_inputs_embeds_shape = inputs_embeds.shape + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze( + -1 + ) + final_mask = special_image_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) + + final_mask_1d = final_mask[..., 0].reshape(-1) + num_tokens_to_fill = final_mask_1d.sum() + + if num_tokens_to_fill != vision_embeds.size(0): + raise ValueError( + f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, " + f"but multi_modal_projector returned {vision_embeds.size(0)}" + ) + + expanded_mask = final_mask_1d.unsqueeze(-1).expand( + -1, inputs_embeds.size(-1) + ) + inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds) + inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) + return inputs_embeds diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 2b1e01df..b76dbe68 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import torch from PIL import Image from io import BytesIO @@ -12,7 +13,7 @@ from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, ) -from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION +from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION, MEM_POOL from loguru import logger from text_generation_server.utils.log import log_master from transformers import AutoProcessor @@ -109,17 +110,17 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): return height // patch_size, width // patch_size -def image_text_replacement(processor, image_input, 
config, image_id: int) -> str: +def image_text_replacement(processor, image_input, config) -> str: if config.model_type == "idefics2": image_seq_len = 64 image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" if processor.image_processor.do_image_splitting: image_str *= 5 - return image_str + return image_str, IDEFICS2_FAKE_TOKEN if config.model_type == "idefics3": # TODO: implement this in a more general way - n_rows = image_input["rows"][0][image_id] - n_cols = image_input["cols"][0][image_id] + n_rows = image_input["rows"][0][0] + n_cols = image_input["cols"][0][0] image_seq_len = int( ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) @@ -132,41 +133,41 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str image_token=IDEFICS3_IMAGE_TOKEN, global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, ) - return image_str + return image_str, IDEFICS3_FAKE_IMAGE_TOKEN elif config.model_type == "llava_next": - height, width = image_input["image_sizes"][image_id] + height, width = image_input["image_sizes"][0] num_features = get_number_of_features(height, width, config) log_master( logger.info, f"Found {num_features} features in image of resolution {height}x{width}", ) - return "" * num_features + return "" * num_features, "" elif config.model_type == "paligemma": - return "" * config.text_config.num_image_tokens + return "" * config.text_config.num_image_tokens, "" elif config.model_type == "qwen2_vl": - grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + grid_t, grid_h, grid_w = image_input["image_grid_thw"][0] num_pads = grid_t * grid_h * grid_w // 4 padding = "<|image_pad|>" * num_pads - return f"<|vision_start|>{padding}<|vision_end|>" + return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>" elif config.model_type == "qwen2_5_vl": - grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + grid_t, grid_h, grid_w = image_input["image_grid_thw"][0] num_pads = grid_t * grid_h * grid_w // 4 padding = "<|image_pad|>" * num_pads - return f"<|vision_start|>{padding}<|vision_end|>" + return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>" elif config.model_type == "gemma3": # TODO: get correct number of features via reviewing the Gemma3 architecture # and calculating the number of image tokens num_pads = 256 padding = "" * num_pads - return f"\n\n{padding}\n\n" + return f"\n\n{padding}\n\n", "" elif config.model_type == "llama4": patch_size = config.vision_config.patch_size pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2))) - aspect_ratios = image_input["aspect_ratios"][image_id] - image_height, image_width = image_input["pixel_values"][image_id].shape[-2:] + aspect_ratios = image_input["aspect_ratios"][0] + image_height, image_width = image_input["pixel_values"][0].shape[-2:] num_patches_per_chunk = int( (image_height // patch_size) @@ -177,7 +178,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str aspect_ratios, num_patches_per_chunk ) - return tokens_for_this_image + return tokens_for_this_image, "<|image_start|>" else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -190,6 +191,27 @@ def image_text_replacement_fixup(config, text: str) -> str: return text +def preprocess_text(config, text: str) -> str: + if config.model_type == "paligemma": + return "" + text + "\n" + return text + + +def 
preprocess_image(config, img): + model_type = config.model_type + + if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20: + img = img.resize((img.width * 2, img.height * 2)) + if model_type == "paligemma": + img = img.convert("RGB") + + if model_type not in {"llava_next", "gemma3", "llama4"}: + # TODO: check if this is needed + img = [img] + + return img + + def get_unpadded_features( original_height: int, original_width: int, @@ -244,105 +266,263 @@ def get_number_of_features(height: int, width: int, config) -> int: return unpadded_features + newline_features + base_features +def scatter_image_embeds( + embeds: torch.Tensor, is_embed: Optional[torch.Tensor] +) -> torch.Tensor: + if is_embed is None: + return embeds + + placeholders = embeds.new_full( + (is_embed.shape[0], embeds.shape[-1]), + fill_value=torch.nan, + ) + placeholders[is_embed] = embeds + return placeholders + + +def gather_image_embeds( + embeds: torch.Tensor, is_embed: Optional[torch.Tensor] +) -> Optional[torch.Tensor]: + if is_embed is None: + return embeds + sel = embeds[is_embed] + return sel if sel.numel() else None + + +@dataclass +class ImagePositions: + offset: int + length: int + id: int + num_placeholder_tokens: int + is_embed: Optional[torch.Tensor] = None + + class VlmCausalLMBatch(FlashCausalLMBatch): + image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]] + image_positions: Optional[List[List[ImagePositions]]] + encoder_cache: Optional[List[Dict[int, torch.Tensor]]] pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] image_grid_thw: Optional[torch.Tensor] + cache_entries_to_free: List[Tuple[int, int]] + has_image_inputs: bool = False + inputs_embeds: Optional[torch.Tensor] = None @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches): batch = super(VlmCausalLMBatch, cls).concatenate(batches) + + batch.image_inputs = [] + batch.image_positions = [] + batch.encoder_cache = [] + for b in batches: + if b.image_inputs is not None: + batch.image_inputs.extend(b.image_inputs) + else: + batch.image_inputs.append(None) + if b.image_positions is not None: + batch.image_positions.extend(b.image_positions) + else: + batch.image_positions.append(None) + if b.encoder_cache is not None: + batch.encoder_cache.extend(b.encoder_cache) + else: + batch.encoder_cache.append(None) + batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None batch.image_grid_thw = None + batch.inputs_embeds = None + + # To be filled in prepare_for_prefill + batch.has_image_inputs = False + batch.cache_entries_to_free = [] + return batch @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]): + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + + image_inputs = [] + image_positions = [] + encoder_cache = [] + + for request_id in request_ids: + idx = self.requests_idx_mapping[request_id] + image_inputs.append(self.image_inputs[idx]) + image_positions.append(self.image_positions[idx]) + encoder_cache.append(self.encoder_cache[idx]) + batch = super().filter(request_ids) batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None batch.image_grid_thw = None + batch.inputs_embeds = None + + batch.image_inputs = image_inputs + batch.image_positions = image_positions + batch.encoder_cache = encoder_cache + + # To be filled in prepare_for_prefill + batch.has_image_inputs = False + batch.cache_entries_to_free = 
[] return batch @classmethod def batch_tokenized_inputs( cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config ): - # Process images first. We need all of them so that the processor - # can make the image splits the same size. And we need the final - # sizes to insert correct number of image tokens. - images = [] + kwargs = {} + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == "Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True + + max_length = 0 + vocab = tokenizer.get_vocab() + + if not hasattr(config, "image_token_index"): + config.image_token_index = config.image_token_id + + batch_tokenized_inputs: List[List[int]] = [] + batch_image_inputs: List[Optional[List[dict]]] = [] + batch_image_positions: List[Optional[List[ImagePositions]]] = [] + for r in requests: + text_parts = [] + image_inputs = [] + image_texts = [] + + image_id = 0 + for chunk in r.input_chunks.chunks: chunk_type = chunk.WhichOneof("chunk") if chunk_type == "text": - pass + text = preprocess_text(config, chunk.text) + text_parts.append(text) elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the - # default warmup image is 20x20 - if config.model_type in {"qwen2_vl", "qwen2_5_vl"}: - if image.width <= 20: - w = image.width * 2 - h = image.height * 2 - image = image.resize((w, h)) + img = Image.open(BytesIO(chunk.image.data)) + img = preprocess_image(config, img) - if config.model_type == "llava_next": - images.append(image) - elif config.model_type == "gemma3": - images.append(image) - elif config.model_type == "llama4": - images.append(image) - else: - images.append([image]) + image_input = processor.image_processor( + [img], return_tensors="pt", **kwargs + ) + image_inputs.append(image_input) + + img_text, img_start_token_str = image_text_replacement( + processor, image_input, config + ) + text_parts.append(img_text) + + image_texts.append([image_id, img_start_token_str, img_text]) + image_id += 1 else: raise RuntimeError(f"Invalid chunk type {chunk_type}") - if images: - kwargs = {} - if ( - hasattr(processor, "image_processor_class") - and processor.image_processor_class == "Idefics3ImageProcessor" - ): - kwargs["return_row_col_info"] = True - - image_inputs = processor.image_processor( - images, return_tensors="pt", **kwargs - ) - else: - image_inputs = None - - batch_tokenized_inputs = [] - max_length = 0 - image_id = 0 - for r in requests: - full_text = "" - for chunk in r.input_chunks.chunks: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - full_text += chunk.text - elif chunk_type == "image": - full_text += image_text_replacement( - processor, image_inputs, config, image_id - ) - image_id += 1 - # from pdb import set_trace; set_trace() - full_text = image_text_replacement_fixup(config, full_text) + full_text = image_text_replacement_fixup(config, "".join(text_parts)) input_ids = tokenizer( full_text, truncation=True, max_length=r.truncate, - add_special_tokens=r.add_special_tokens, + add_special_tokens=( + r.add_special_tokens if config.model_type != "paligemma" else False + ), )["input_ids"] max_length = max(max_length, len(input_ids)) - batch_tokenized_inputs.append(input_ids) - return batch_tokenized_inputs, image_inputs + if len(image_inputs) > 0: + img_start_token = vocab[image_texts[0][1]] + image_positions = cls.get_image_positions( + input_ids, image_texts, img_start_token, config, tokenizer + ) 
+ else: + image_inputs = None + image_positions = None + + batch_tokenized_inputs.append(input_ids) + batch_image_inputs.append(image_inputs) + batch_image_positions.append(image_positions) + + return batch_tokenized_inputs, batch_image_inputs, batch_image_positions + + @classmethod + def get_image_positions( + cls, + input_ids: List[int], + image_texts: List[Tuple[int, str, str]], + img_start_token: int, + config, + tokenizer: PreTrainedTokenizerBase, + ) -> List[ImagePositions]: + image_positions = [] + num_images = len(image_texts) + + input_ids_t = torch.as_tensor(input_ids) + img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0] + num_tokens = input_ids_t.numel() + + last_pos = 0 + for i in range(num_images): + image_id, img_start_token_str, img_text = image_texts[i] + img_text = image_text_replacement_fixup(config, img_text) + + if config.model_type == "gemma3": + img_text = img_text.replace("\n\n", "") + + tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[ + "input_ids" + ][0] + length = tokens.numel() + + assert ( + length <= num_tokens + ), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens" + + pos = torch.searchsorted(img_start_token_pos, last_pos, right=False) + index = img_start_token_pos[pos] + assert torch.equal( + input_ids_t[index : index + length], tokens + ), "Image tokens not found in input_ids" + + is_embed = tokens == config.image_token_index + num_placeholder_tokens = int(is_embed.sum()) + if num_placeholder_tokens == length: + is_embed = None + + pos = ImagePositions( + offset=index, + length=length, + id=image_id, + num_placeholder_tokens=num_placeholder_tokens, + is_embed=is_embed, + ) + + image_positions.append(pos) + last_pos = index + length + + if ( + config.model_type == "idefics2" + and i + 1 != num_images + and input_ids[last_pos] == config.image_token_index + ): + fake_token = last_pos - 1 + fake_token_index = torch.searchsorted( + img_start_token_pos, fake_token, right=False + ) + img_start_token_pos[fake_token_index] = last_pos + image_texts[i + 1][2] = image_texts[i + 1][2][ + len(img_start_token_str) : + ] + + return image_positions @classmethod def from_pb_processor( @@ -354,33 +534,164 @@ class VlmCausalLMBatch(FlashCausalLMBatch): dtype: torch.dtype, device: torch.device, ) -> "VlmCausalLMBatch": - batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( - pb.requests, tokenizer, processor, config + batch_tokenized_inputs, image_inputs, image_positions = ( + cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config) ) batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) - if image_inputs is not None: - batch.pixel_values = image_inputs["pixel_values"].to(device=device) - if "pixel_attention_mask" in image_inputs: - batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( - device=device - ) - else: - batch.pixel_attention_mask = None - if "image_sizes" in image_inputs: - batch.image_sizes = image_inputs["image_sizes"].to(device=device) - else: - batch.image_sizes = None - if "image_grid_thw" in image_inputs: - batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device) - else: - batch.image_grid_thw = None - else: + batch.image_inputs = image_inputs + batch.image_positions = image_positions + batch.encoder_cache = [{} for _ in range(len(pb.requests))] + if len(image_inputs): batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None batch.image_grid_thw = None return batch 
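The scatter_image_embeds / gather_image_embeds helpers added above are inverses: scatter stores an image's embeddings in a placeholder tensor aligned with its text replacement (NaN rows where the replacement uses marker tokens rather than placeholder tokens), and gather recovers exactly the embedding rows covered by the current prefill chunk. A minimal, self-contained sketch of that round trip with made-up shapes (illustrative only, not part of the patch):

import torch

hidden_size = 8
# A 6-token image replacement: 4 placeholder tokens (True) framed by 2 marker tokens.
is_embed = torch.tensor([False, True, True, True, True, False])
vision_embeds = torch.randn(int(is_embed.sum()), hidden_size)

# scatter_image_embeds: NaN rows for marker tokens, embeddings at placeholder rows.
cached = torch.full((is_embed.shape[0], hidden_size), float("nan"))
cached[is_embed] = vision_embeds

# gather_image_embeds over a prefill chunk covering tokens 2..5 of the replacement.
start_idx, end_idx = 2, 5
chunk = cached[start_idx:end_idx][is_embed[start_idx:end_idx]]
assert torch.equal(chunk, vision_embeds[1:4])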
+ def prepare_for_prefill(self): + super().prepare_for_prefill() + + self.has_image_inputs = False + self.cache_entries_to_free = [] + + self.pixel_values = [] + + assert ( + len(self.cache_lengths) + == len(self.input_lengths) + == len(self.prefilling_mask) + ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask" + + for i, ( + cache_length, + input_length, + request_prefilling, + ) in enumerate( + zip( + self.cache_lengths, + self.input_lengths, + self.prefilling_mask, + ) + ): + if not request_prefilling or self.image_positions[i] is None: + continue + + for image_position in self.image_positions[i]: + if image_position is None: + continue + start_pos = image_position.offset + length = image_position.length + + if start_pos >= cache_length + input_length: + # No encoder input required at this step + break + if start_pos + length <= cache_length: + # The encode input is already processed + continue + + self.has_image_inputs = True + + if image_position.id not in self.encoder_cache[i]: + image_inputs = self.image_inputs[i][image_position.id] + self.pixel_values.append((i, image_position.id, image_inputs)) + + # Remove the image from the image_inputs + self.image_inputs[i][image_position.id] = None + + if not self.has_image_inputs: + self.pixel_values = None + self.pixel_attention_mask = None + self.image_sizes = None + self.image_grid_thw = None + else: + image_grid_thw_list = [ + x[2]["image_grid_thw"] + for x in self.pixel_values + if "image_grid_thw" in x[2] + ] + if image_grid_thw_list: + self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to( + self.input_ids.device + ) + else: + self.image_grid_thw = None + + def update_encoder_cache(self, encoder_outputs, request_id, img_pos): + self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds( + encoder_outputs, img_pos.is_embed + ) + + def gather_vision_embeds(self): + device = self.input_ids.device + chunks = [] + for ( + i, + cache_length, + input_length, + request_prefilling, + ) in zip( + range(len(self.requests)), + self.cache_lengths, + self.input_lengths, + self.prefilling_mask, + ): + if not request_prefilling or self.image_positions[i] is None: + continue + + for image_position in self.image_positions[i]: + if image_position is None: + continue + start_pos = image_position.offset + length = image_position.length + + if start_pos >= cache_length + input_length: + # No encoder input required at this step + break + if start_pos + length <= cache_length: + # The encode input is already processed + continue + + start_idx = max(cache_length - start_pos, 0) + end_idx = min(cache_length - start_pos + input_length, length) + + assert ( + image_position.id in self.encoder_cache[i] + ), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}" + encoder_output = self.encoder_cache[i][image_position.id] + + is_embed = image_position.is_embed + if is_embed is not None: + is_embed = is_embed[start_idx:end_idx] + + from loguru import logger + + logger.info( + f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}" + ) + + embeds = gather_image_embeds( + encoder_output[start_idx:end_idx], + is_embed=is_embed, + ) + if embeds is not None: + chunks.append(embeds) + + if end_idx == length: + self.cache_entries_to_free.append((i, image_position.id)) + self.image_positions[i][image_position.id] = None + + if len(chunks) == 0: + return None + return torch.cat(chunks, dim=0).to(device) + + def free_encoder_cache(self): + for i, image_id in 
self.cache_entries_to_free: + self.encoder_cache[i].pop(image_id, None) + + self.cache_entries_to_free = [] + + # release any freed GPU memory immediately? + class VlmCausalLM(FlashCausalLM): def __init__( @@ -392,6 +703,7 @@ class VlmCausalLM(FlashCausalLM): batch_class=VlmCausalLMBatch, revision, trust_remote_code: bool, + support_chunking: bool = True, **kwargs, ): if PREFIX_CACHING: @@ -409,8 +721,7 @@ class VlmCausalLM(FlashCausalLM): model_id=model_id, revision=revision, trust_remote_code=trust_remote_code, - # FIXME: VLM do not work with context chunking yet - support_chunking=False, + support_chunking=support_chunking, **kwargs, ) @@ -418,6 +729,227 @@ class VlmCausalLM(FlashCausalLM): def batch_type(self) -> Type[VlmCausalLMBatch]: return self.batch_class + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): + max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None + input_lengths = [max_s] * bs + cache_lengths = [0] * bs + config = getattr(self.model.config, "text_config", self.model.config) + if max_bs is None: + inputs_embeds = torch.zeros( + (bs, config.hidden_size), + device=self.device, + dtype=self.dtype, + ) + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + config = getattr(self.model, "config", None) + rope_scaling = getattr(config, "rope_scaling", None) if config else None + if ( # mrope have position_ids per section, if so repeat n times + isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope" + ): + n_sections = len(self.model.config.rope_scaling["mrope_section"]) + position_ids = position_ids.unsqueeze(1).repeat(1, n_sections) + slots = torch.arange(bs, dtype=torch.int64, device=self.device) + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s + ) + cache_lengths_tensor = torch.zeros( + bs, dtype=torch.int32, device=self.device + ) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + cache_lengths=cache_lengths, + input_lengths_tensor=input_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_s, + ) + else: + if bs > max_bs: + raise RuntimeError( + "Cuda graphs should be generated in decreasing order size to reduce VRAM usage" + ) + inputs_embeds = self.cuda_graphs[max_bs]["inputs_embeds"][:bs] + position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] + if ATTENTION == "flashinfer": + block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] + else: + block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs] + slots = self.cuda_graphs[max_bs]["slots"][:bs] + input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs] + cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs] + + if ATTENTION == "flashinfer": + from text_generation_server.layers.attention.flashinfer import ( + create_decode_state_cuda_graphs, + ) + + block_tables_ptr = torch.zeros( + bs + 1, dtype=torch.int32, device=self.device + ) + last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) + state = create_decode_state_cuda_graphs( + device=inputs_embeds.device, + block_tables=block_tables, + block_tables_ptr=block_tables_ptr, + last_page_len=last_page_len, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + else: + state = None + + graph = torch.cuda.CUDAGraph() + 
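# --- Illustrative sketch, not part of the patch ---------------------------------
# The warmup-then-capture pattern that cuda_graph_warmup continues below: allocate
# static input buffers, run the computation once eagerly so kernels and workspaces
# are initialized, record the same call into a torch.cuda.CUDAGraph, then replay it
# after copying fresh data into the static buffers. Shapes and the toy matmul are
# assumptions.
import torch

if torch.cuda.is_available():
    static_x = torch.zeros(4, 8, device="cuda")       # static buffer reused by the graph
    weight = torch.randn(8, 8, device="cuda")

    # 1) eager warmup: runs outside the capture
    static_out = static_x @ weight
    torch.cuda.synchronize()

    # 2) capture: kernels launched inside this block are recorded, not executed
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = static_x @ weight

    # 3) replay: overwrite the static input in place, then rerun the recording
    static_x.copy_(torch.randn(4, 8, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    print(static_out)                                  # reflects the new static_x
# --------------------------------------------------------------------------------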
self.cuda_graphs[bs] = { + "inputs_embeds": inputs_embeds, + "position_ids": position_ids, + "kv_cache": self.kv_cache, + "block_tables": block_tables, + "slots": slots, + "input_lengths": input_lengths_tensor, + "cache_lengths": cache_lengths_tensor, + "state": state, + "graph": graph, + } + + torch.cuda.synchronize() + # Run once outside to warmup + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=None, + input_lengths_tensor=input_lengths_tensor, + state=state, + cache_lengths_tensor=cache_lengths_tensor, + ): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + self.model.forward( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + del seqlen + + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + logits, speculative_logits = self.model.forward( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits + torch.cuda.synchronize() + + def get_vision_embeds( + self, + pixel_values: torch.Tensor, + pixel_attention_mask: torch.Tensor, + image_sizes: torch.Tensor, + image_grid_thw: torch.Tensor, + ): + embeds = self.model.get_vision_embeds( + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_sizes=image_sizes, + image_grid_thw=image_grid_thw, + ) + return embeds + + def get_inputs_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: Optional[torch.Tensor] = None, + ): + return self.model.get_inputs_embeds( + input_ids=input_ids, + vision_embeds=vision_embeds, + ) + + def encode_images(self, batch): + if batch.pixel_values is not None: + device = batch.input_ids.device + for request_id, image_id, image_input in batch.pixel_values: + pixel_values = image_input["pixel_values"].to(device) + + if "pixel_attention_mask" in image_input: + pixel_attention_mask = image_input["pixel_attention_mask"].to( + device + ) + else: + pixel_attention_mask = None + + if "image_sizes" in image_input: + image_sizes = image_input["image_sizes"].to(device) + else: + image_sizes = None + + if "image_grid_thw" in image_input: + image_grid_thw = image_input["image_grid_thw"].to(device) + else: + image_grid_thw = None + + encoder_outputs = self.get_vision_embeds( + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_sizes=image_sizes, + image_grid_thw=image_grid_thw, + ) + batch.update_encoder_cache( + encoder_outputs, + request_id, + batch.image_positions[request_id][image_id], + ) + + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + + def set_inputs_embeds(self, batch): + if batch.has_image_inputs: + self.encode_images(batch) + vision_embeds = batch.gather_vision_embeds() + batch.has_image_inputs = False + else: + vision_embeds = None + + inputs_embeds = self.get_inputs_embeds( + batch.input_ids, vision_embeds=vision_embeds + ) + + 
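# --- Illustrative sketch, not part of the patch ---------------------------------
# What get_inputs_embeds does with the gathered vision embeddings before the
# assignment below: rows of the token-embedding matrix whose token id equals the
# image placeholder id are overwritten, in order, by the vision features. The
# placeholder id, embedding table and shapes here are assumptions.
import torch

IMAGE_TOKEN_ID = 32000
embedding_table = torch.randn(32001, 16)               # stand-in for text_model.embed_tokens

input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2])
vision_embeds = torch.randn(2, 16)                     # one row per placeholder token

inputs_embeds = embedding_table[input_ids]
image_token_mask = input_ids == IMAGE_TOKEN_ID
inputs_embeds[image_token_mask] = vision_embeds        # row-wise replacement, order preserved

assert torch.equal(inputs_embeds[1], vision_embeds[0])
assert torch.equal(inputs_embeds[3], embedding_table[2])
# --------------------------------------------------------------------------------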
batch.inputs_embeds = inputs_embeds + def forward( self, batch: VlmCausalLMBatch, @@ -468,6 +1000,7 @@ class VlmCausalLM(FlashCausalLM): position_ids = new_position_ids else: input_ids = batch.input_ids + inputs_embeds = batch.inputs_embeds position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache @@ -485,13 +1018,17 @@ class VlmCausalLM(FlashCausalLM): ) batch.position_ids = position_ids + attention_mask = None + attention_mask_forward = None if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None: - # Get the mask, needed for flashinfer. attention_mask = self.model.get_attention_mask( input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True - ).reshape(-1) - else: - attention_mask = None + ) + min_dtype = torch.finfo(self.dtype).min + attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to( + input_ids.device + ) + attention_mask = attention_mask.reshape(-1) # Try to find an associated cuda graph bs = input_ids.shape[0] @@ -526,7 +1063,7 @@ class VlmCausalLM(FlashCausalLM): max_k=batch.max_current_length, ) logits, speculative_logits = self.model.forward( - input_ids=input_ids, + inputs_embeds=inputs_embeds, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, @@ -536,26 +1073,17 @@ class VlmCausalLM(FlashCausalLM): max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, - pixel_values=batch.pixel_values, - pixel_attention_mask=batch.pixel_attention_mask, - image_sizes=batch.image_sizes, - image_grid_thw=batch.image_grid_thw, + attention_mask=attention_mask_forward, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None - if batch.image_grid_thw is not None: - batch.image_grid_thw = None + batch.image_grid_thw = None + batch.free_encoder_cache() return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["inputs_embeds"][: inputs_embeds.shape[0]] = inputs_embeds cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( @@ -600,4 +1128,6 @@ class VlmCausalLM(FlashCausalLM): else None ) logits = cuda_graph["logits"][:bs] + + batch.free_encoder_cache() return logits, speculative_logits diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 935e0985..87c51eb9 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -18,7 +18,6 @@ from text_generation_server.utils.adapter import AdapterInfo from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens try: - from text_generation_server.models.pali_gemma import PaliGemmaBatch from text_generation_server.models.vlm_causal_lm import ( VlmCausalLMBatch, ) @@ -26,7 +25,6 @@ try: from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch VLM_BATCH_TYPES = { - PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch, MllamaCausalLMBatch,
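# --- Illustrative sketch, not part of the patch ---------------------------------
# The boolean -> additive mask conversion used for gemma3 in forward() above:
# positions a query may attend to contribute 0.0 to the attention scores, while
# masked positions get the dtype's most negative value and vanish after softmax.
# The sequence length and causal shape are assumptions.
import torch

dtype = torch.bfloat16
seq_len = 4
bool_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()   # True = may attend

min_dtype = torch.finfo(dtype).min
additive_mask = torch.where(bool_mask, 0.0, min_dtype).to(dtype)

scores = torch.randn(seq_len, seq_len, dtype=dtype)
probs = torch.softmax(scores + additive_mask, dim=-1)
print(probs[0])   # the first query position attends only to the first key
# --------------------------------------------------------------------------------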