From be8e60a9186df89137107d82b09d86eb12c92e04 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Mon, 21 Apr 2025 15:25:03 +0000 Subject: [PATCH] add improvements --- integration-tests/requirements.txt | 2 +- .../text_generation_server/models/__init__.py | 8 +- .../custom_modeling/flash_gemma3_modeling.py | 52 ++- .../flash_pali_gemma_modeling.py | 47 ++- .../models/custom_modeling/idefics3.py | 158 ++++---- .../models/custom_modeling/llava_next.py | 200 +++++----- .../models/custom_modeling/mllama.py | 1 + .../models/custom_modeling/qwen2_5_vl.py | 34 +- .../models/custom_modeling/qwen2_vl.py | 34 +- .../models/mllama_causal_lm.py | 6 + .../models/transformers_flash_vlm.py | 5 +- .../models/vlm_causal_lm.py | 356 ++++++++++++++++-- 12 files changed, 654 insertions(+), 249 deletions(-) diff --git a/integration-tests/requirements.txt b/integration-tests/requirements.txt index ca2dee93..fe4e929f 100644 --- a/integration-tests/requirements.txt +++ b/integration-tests/requirements.txt @@ -39,7 +39,7 @@ httpcore==1.0.7 # via httpx httpx==0.28.1 # via openai -huggingface-hub==0.29.3 +huggingface-hub==0.30.1 # via # text-generation-integration-tests (pyproject.toml) # text-generation diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 93a2f8bf..20d02bbe 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -128,9 +128,6 @@ try: from text_generation_server.models.custom_modeling.flash_neox_modeling import ( FlashGPTNeoXForCausalLM, ) - from text_generation_server.models.pali_gemma import ( - PaliGemmaBatch, - ) from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( PaliGemmaForConditionalGeneration, ) @@ -1196,6 +1193,7 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + support_chunking=False, ) elif FLASH_TRANSFORMERS_BACKEND: from transformers import Gemma3ForConditionalGeneration as Gemma3Model @@ -1208,6 +1206,7 @@ def get_model( speculator=speculator, dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + support_chunking=False, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3")) @@ -1583,6 +1582,7 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + support_chunking=False, ) # TODO: Uncomment when transformers is refactored and cross attn is added # elif FLASH_TRANSFORMERS_BACKEND: @@ -1676,7 +1676,6 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, - batch_class=PaliGemmaBatch, ) elif FLASH_TRANSFORMERS_BACKEND: from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel @@ -1689,7 +1688,6 @@ def get_model( speculator=speculator, dtype=torch.bfloat16, trust_remote_code=trust_remote_code, - batch_class=PaliGemmaBatch, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma")) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py index 58afd643..35b29ce0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py @@ -700,6 +700,7 @@ class Gemma3ForConditionalGeneration(nn.Module): self.pad_token_id = ( 
config.pad_token_id if config.pad_token_id is not None else -1 ) + self.dtype = weights.dtype def get_attention_mask( self, @@ -762,6 +763,38 @@ class Gemma3ForConditionalGeneration(nn.Module): else: return torch.where(full_attention_mask, 0, min_dtype).to(device) + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + **kwargs, + ): + pixel_values = pixel_values.to(dtype=self.dtype) + image_outputs = self.vision_model(pixel_values) + vision_outputs = self.post_vision_model_layernorm( + image_outputs.last_hidden_state + ) + image_features = self.multimodal_projector(vision_outputs) + image_features = image_features.view(-1, image_features.shape[-1]) + return image_features + + def get_input_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + **kwargs, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # Replace the image token embeddings with the vision features + image_token_mask = (input_ids == self.config.image_token_index).to( + input_ids.device + ) + inputs_embeds[image_token_mask] = vision_embeds.view( + -1, vision_embeds.shape[-1] + ) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -781,26 +814,17 @@ class Gemma3ForConditionalGeneration(nn.Module): image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - inputs_embeds = self.text_model.embed_tokens(input_ids) if cu_seqlen_prefill is not None: max_s += 1 position_ids += 1 - if pixel_values is not None: - pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) - image_outputs = self.vision_model(pixel_values) - vision_outputs = self.post_vision_model_layernorm( - image_outputs.last_hidden_state - ) - image_features = self.multimodal_projector(vision_outputs) + image_token_mask = (input_ids == self.config.image_token_index).to( + input_ids.device + ) - image_token_mask = (input_ids == self.config.image_token_index).to( - input_ids.device - ) - inputs_embeds[image_token_mask] = image_features.view( - -1, image_features.shape[-1] - ) + if torch.any(image_token_mask): attention_mask = self.get_attention_mask( input_ids, cu_seqlen_prefill, diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index b1f89eff..a6ceade8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -62,6 +62,37 @@ class PaliGemmaForConditionalGeneration(nn.Module): self.pad_token_id = ( config.pad_token_id if config.pad_token_id is not None else -1 ) + self.dtype = weights.dtype + + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + **kwargs, + ): + pixel_values = pixel_values.to(dtype=self.dtype) + image_outputs = self.vision_tower(pixel_values) + last_hidden_state = self.post_vision_tower_layernorm( + image_outputs.last_hidden_state + ) + image_features = self.multi_modal_projector(last_hidden_state) + image_features = image_features.view( + image_features.shape[0], image_features.shape[1], -1 + ) + return image_features + + def get_input_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + **kwargs, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if 
vision_embeds is not None: + mask = input_ids == self.config.image_token_index + inputs_embeds[mask] = vision_embeds.view(-1, vision_embeds.shape[-1]) + + return inputs_embeds def forward( self, @@ -81,27 +112,13 @@ class PaliGemmaForConditionalGeneration(nn.Module): image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. if cu_seqlen_prefill is not None: max_s += 1 position_ids += 1 - if pixel_values is not None: - pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) - image_outputs = self.vision_tower(pixel_values) - last_hidden_state = self.post_vision_tower_layernorm( - image_outputs.last_hidden_state - ) - image_features = self.multi_modal_projector(last_hidden_state) - - # mask where image or padding tokens - mask = input_ids == self.config.image_token_index - - # insert image features into input embeddings - inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py index 580398cb..71e00442 100644 --- a/server/text_generation_server/models/custom_modeling/idefics3.py +++ b/server/text_generation_server/models/custom_modeling/idefics3.py @@ -476,6 +476,96 @@ class Idefics3ForConditionalGeneration(nn.Module): inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) return inputs_embeds + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: torch.BoolTensor, + **kwargs, + ): + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. 
+ nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + ) + + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + + return image_hidden_states.view(-1, image_hidden_states.shape[-1]) + + def get_input_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + pixel_values: torch.FloatTensor = None, + pixel_attention_mask: torch.BoolTensor = None, + **kwargs, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if vision_embeds is None and pixel_values is not None: + vision_embeds = self.get_vision_embeds( + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + ) + + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, vision_embeds + ) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -497,74 +587,8 @@ class Idefics3ForConditionalGeneration(nn.Module): video_grid_thw: Optional[torch.LongTensor] = None, cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, + inputs_embeds: Optional[torch.Tensor] = None, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - all_states = [] - all_pixel_values = pixel_values - all_pixel_mask = pixel_attention_mask - for i in range(batch_size): - pixel_values = all_pixel_values.to( - dtype=self.dtype - ) # fp16 compatibility - pixel_values = pixel_values[i : i + 1] - pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. 
- nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3) - ) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=( - pixel_values.size(0), - pixel_values.size(2), - pixel_values.size(3), - ), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask/pP p - pixel_attention_mask = all_pixel_mask[i : i + 1] - pixel_attention_mask = pixel_attention_mask.view( - 1 * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[ - real_images_inds - ].contiguous() - - patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold( - dimension=1, size=patch_size, step=patch_size - ) - patches_subgrid = patches_subgrid.unfold( - dimension=2, size=patch_size, step=patch_size - ) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - # Modality projection & resampling - image_hidden_states = self.connector( - image_hidden_states, - ) - - all_states.append(image_hidden_states) - image_hidden_states = torch.stack(all_states, dim=0) - - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_hidden_states - ) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index df7366ea..27861375 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -163,6 +163,116 @@ class LlavaNextForConditionalGeneration(nn.Module): ) return inputs_embeds + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + image_sizes: Optional[torch.LongTensor] = None, + **kwargs, + ): + # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() + # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" + # 1. Extract the input embeddings + + # 2. Merge text and images + num_images, num_patches, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view( + num_images * num_patches, channels, height, width + ) + image_features = self.vision_tower(pixel_values) + + # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] + # Already done within the clip model + selected_image_feature = image_features.last_hidden_state + + if self.config.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise RuntimeError( + f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." 
+ ) + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [num_patches] * num_images + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = ( + self.config.vision_config.image_size // self.config.vision_config.patch_size + ) + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with the image size." + ) + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1 + ), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, self.image_newline[None]), dim=0 + ) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + return image_features.view(-1, image_features.shape[-1]) + + def get_input_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + **kwargs, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if vision_embeds is None and pixel_values is not None: + vision_embeds = self.get_vision_embeds( + pixel_values=pixel_values, + image_sizes=image_sizes, + ) + + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, vision_embeds + ) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -181,96 +291,8 @@ class LlavaNextForConditionalGeneration(nn.Module): image_sizes: Optional[torch.LongTensor] = None, adapter_data: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None and len(pixel_values) > 0: - # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() - # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" - # 1. Extract the input embeddings - - # 2. 
Merge text and images - num_images, num_patches, channels, height, width = pixel_values.shape - pixel_values = pixel_values.view( - num_images * num_patches, channels, height, width - ) - image_features = self.vision_tower(pixel_values) - - # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] - # Already done within the clip model - selected_image_feature = image_features.last_hidden_state - - if self.config.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.config.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - else: - raise RuntimeError( - f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." - ) - - image_features = self.multi_modal_projector(selected_image_feature) - - # split up image_features for each of the individual images - # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) - # if we assume each image has 5 image features (base image + 4 patches) - split_sizes = [num_patches] * num_images - image_features = torch.split(image_features, split_sizes, dim=0) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - height = width = ( - self.config.vision_config.image_size - // self.config.vision_config.patch_size - ) - - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - - if height * width != base_image_feature.shape[0]: - raise ValueError( - "The number of patches is not consistent with the image size." - ) - - # Dimensions are intentionally swapped to be bug-compatible with - # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_sizes[image_idx], - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, - ) - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 - ) - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat( - ( - image_feature, - self.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1 - ), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat( - (base_image_feature, image_feature), dim=0 - ) - else: - image_feature = image_feature[0] - image_feature = torch.cat( - (image_feature, self.image_newline[None]), dim=0 - ) - new_image_features.append(image_feature) - image_features = torch.stack(new_image_features, dim=0) - - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_features - ) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py index be0a4b5d..7d60c098 100644 --- a/server/text_generation_server/models/custom_modeling/mllama.py +++ b/server/text_generation_server/models/custom_modeling/mllama.py @@ -959,6 +959,7 @@ class MllamaForConditionalGeneration(nn.Module): # XXX: Putting these as optional so that the cuda warmup calls can go through. 
cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, + inputs_embeds=None, ): if cross_attention_states is not None: seqlen_q = len(image_indices) diff --git a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index 066de6a2..1d48cb34 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -922,6 +922,29 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): ) return position_ids + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + **kwargs, + ): + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0) + return image_embeds + + def get_input_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + **kwargs, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if vision_embeds is not None and len(vision_embeds) > 0: + inputs_embeds[input_ids == self.image_token_id] = vision_embeds + + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -943,17 +966,8 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): adapter_data: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, + inputs_embeds: Optional[torch.Tensor] = None, ): - inputs_embeds = self.embed_tokens(input_ids) - - # apply the visual model to the pixel values if they are provided - if pixel_values is not None and len(pixel_values) > 0: - if pixel_values is not None: - image_embeds = self.visual( - pixel_values, grid_thw=image_grid_thw - ).squeeze(0) - inputs_embeds[input_ids == self.image_token_id] = image_embeds - hidden_states = self.text_model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 26e6fede..7c5ed470 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -500,6 +500,29 @@ class Qwen2VLForConditionalGeneration(nn.Module): ) return position_ids + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + **kwargs, + ): + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0) + return image_embeds + + def get_input_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + **kwargs, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if vision_embeds is not None and len(vision_embeds) > 0: + inputs_embeds[input_ids == self.image_token_id] = vision_embeds + + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -520,17 +543,8 @@ class Qwen2VLForConditionalGeneration(nn.Module): adapter_data: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, + inputs_embeds: Optional[torch.Tensor] = None, ): - inputs_embeds = self.embed_tokens(input_ids) - - # apply the visual model to the pixel values if they are provided - if pixel_values is not None and len(pixel_values) > 0: - if pixel_values is not None: - image_embeds = self.visual( - pixel_values, grid_thw=image_grid_thw - 
).squeeze(0) - inputs_embeds[input_ids == self.image_token_id] = image_embeds - hidden_states = self.text_model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py index c268ff9a..9078160b 100644 --- a/server/text_generation_server/models/mllama_causal_lm.py +++ b/server/text_generation_server/models/mllama_causal_lm.py @@ -29,6 +29,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): aspect_ratio_mask: Optional[torch.Tensor] = None cross_attention_states: Optional[torch.Tensor] = None + def prepare_for_prefill(self): + super(VlmCausalLMBatch, self).prepare_for_prefill() + @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches): @@ -196,6 +199,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): class MllamaCausalLM(VlmCausalLM): + def get_input_embeddings(self, batch): + batch.inputs_embeds = None + def forward( self, batch: MllamaCausalLMBatch, diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py index f7b84ffe..1f4a053b 100644 --- a/server/text_generation_server/models/transformers_flash_vlm.py +++ b/server/text_generation_server/models/transformers_flash_vlm.py @@ -163,6 +163,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): processor_kwargs=None, kv_cache_dtype: Optional[torch.dtype] = None, batch_class=VlmCausalLMBatch, + support_chunking: bool = True, ): self.batch_class = batch_class self.quantize = quantize @@ -304,7 +305,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): device=device, rank=rank, world_size=world_size, - support_chunking=True, + support_chunking=support_chunking, ) # Monkey patch of `self.model.forward` to match `FlashCausalLM`. 
It avoids duplicating a lot of code @@ -339,6 +340,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): trust_remote_code: bool = False, batch_class: Optional[type] = VlmCausalLMBatch, processor_kwargs: Optional[dict] = None, + support_chunking: bool = True, ): return cls( model_id=model_id, @@ -350,6 +352,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): trust_remote_code=trust_remote_code, batch_class=batch_class, processor_kwargs=processor_kwargs, + support_chunking=support_chunking, ) def _model_forward( diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index bcc67134..5394e1c1 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -13,7 +13,7 @@ from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, ) -from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION +from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION, MEM_POOL from loguru import logger from text_generation_server.utils.log import log_master from transformers import AutoProcessor @@ -119,8 +119,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str return image_str, IDEFICS2_FAKE_TOKEN if config.model_type == "idefics3": # TODO: implement this in a more general way - n_rows = image_input["rows"][0][image_id] - n_cols = image_input["cols"][0][image_id] + n_rows = image_input[image_id]["rows"][0][0] + n_cols = image_input[image_id]["cols"][0][0] image_seq_len = int( ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) @@ -135,7 +135,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str ) return image_str, IDEFICS3_FAKE_IMAGE_TOKEN elif config.model_type == "llava_next": - height, width = image_input["image_sizes"][image_id] + height, width = image_input[image_id]["image_sizes"][0] num_features = get_number_of_features(height, width, config) log_master( @@ -147,12 +147,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str elif config.model_type == "paligemma": return "" * config.text_config.num_image_tokens, "" elif config.model_type == "qwen2_vl": - grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0] num_pads = grid_t * grid_h * grid_w // 4 padding = "<|image_pad|>" * num_pads return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>" elif config.model_type == "qwen2_5_vl": - grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0] num_pads = grid_t * grid_h * grid_w // 4 padding = "<|image_pad|>" * num_pads return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>" @@ -344,8 +344,155 @@ class VlmCausalLMBatch(FlashCausalLMBatch): def batch_tokenized_inputs( cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config ): - # Process images first. We need all of them so that the processor - # can make the image splits the same size. 
And we need the final + kwargs = {} + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == "Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True + + max_length = 0 + vocab = tokenizer.get_vocab() + config.image_token_index = ( + config.image_token_index + if hasattr(config, "image_token_index") + else config.image_token_id + ) + + batch_tokenized_inputs: List[List[int]] = [] + batch_image_inputs: List[Optional[List[dict]]] = [] + batch_image_positions: List[Optional[List[ImagePositions]]] = [] + + for i, r in enumerate(requests): + text_parts = [] + image_inputs = [] + image_texts = [] + + image_id = 0 + + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + text_parts.append(chunk.text) + continue + + if chunk_type != "image": + raise RuntimeError(f"Invalid chunk type {chunk_type}") + + img = Image.open(BytesIO(chunk.image.data)) + + if config.model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20: + img = img.resize((img.width * 2, img.height * 2)) + + if config.model_type in {"paligemma"}: + img = img.convert("RGB") + + if config.model_type not in {"llava_next", "gemma3", "llama4"}: + img = [img] + + image_input = processor.image_processor( + [img], return_tensors="pt", **kwargs + ) + image_inputs.append(image_input) + + img_text, id_token_str = image_text_replacement( + processor, image_input, config, 0 + ) + + text_parts.append(img_text) + + image_texts.append([image_id, id_token_str, img_text]) + image_id += 1 + + full_text = image_text_replacement_fixup(config, "".join(text_parts)) + input_ids = tokenizer( + full_text, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + max_length = max(max_length, len(input_ids)) + + if len(image_inputs) > 0: + img_start_token = vocab[image_texts[0][1]] + image_positions = cls.get_image_positions( + input_ids, image_texts, img_start_token, config, tokenizer + ) + else: + image_inputs = None + image_positions = None + + batch_tokenized_inputs.append(input_ids) + batch_image_inputs.append(image_inputs) + batch_image_positions.append(image_positions) + + return batch_tokenized_inputs, batch_image_inputs, batch_image_positions + + @classmethod + def get_image_positions( + cls, + input_ids: List[int], + image_texts: List[Tuple[int, str, str]], + img_start_token: int, + config, + tokenizer: PreTrainedTokenizerBase, + ) -> List[ImagePositions]: + image_positions = [] + num_images = len(image_texts) + + input_ids_t = torch.as_tensor(input_ids, dtype=torch.int32) + img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0] + + last_pos = 0 + for i in range(num_images): + image_id, img_start_token_str, img_text = image_texts[i] + img_text = image_text_replacement_fixup(config, img_text) + if config.model_type == "gemma3": + img_text = img_text.replace("\n\n", "") + + tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"] + + pos = torch.searchsorted(img_start_token_pos, last_pos, right=False) + index = img_start_token_pos[pos] + + is_embed = torch.tensor(tokens) == config.image_token_index + num_placeholder_tokens = is_embed.sum().item() + + length = len(tokens) + if num_placeholder_tokens == length: + is_embed = None + + pos = ImagePositions( + offset=index, + length=length, + id=image_id, + num_placeholder_tokens=num_placeholder_tokens, + is_embed=is_embed, + ) + + image_positions.append(pos) + last_pos = index + length + + if ( + config.model_type == "idefics2" + and 
i + 1 != num_images + and input_ids[last_pos] == config.image_token_index + ): + fake_token = last_pos - 1 + fake_token_index = torch.searchsorted( + img_start_token_pos, fake_token, right=False + ) + img_start_token_pos[fake_token_index] = last_pos + image_texts[i + 1][2] = image_texts[i + 1][2][ + len(img_start_token_str) : + ] + + return image_positions + + @classmethod + def batch_tokenized_inputs2( + cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config + ): # sizes to insert correct number of image tokens. kwargs = {} if ( @@ -374,21 +521,20 @@ class VlmCausalLMBatch(FlashCausalLMBatch): if config.model_type in {"llava_next", "gemma3", "llama4"}: image = image + elif config.model_type in {"paligemma"}: + image = image.convert("RGB") else: image = [image] - pixel_values = processor.image_processor( + image_input = processor.image_processor( [image], return_tensors="pt", **kwargs ) - image_inputs.append(pixel_values) + image_inputs.append(image_input) else: raise RuntimeError(f"Invalid chunk type {chunk_type}") if len(image_inputs) > 0: batch_image_inputs[i] = image_inputs - # pixel_values = processor.image_processor( - # all_images, return_tensors="pt", **kwargs - # ) batch_image_positions = [] batch_tokenized_inputs = [] @@ -554,29 +700,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch): if image_id not in self.encoder_cache[i]: self.pixel_values.append((i, image_position, image_inputs)) - # scheduled_image_pixel_values.append(image_inputs) self.image_inputs[i][j] = None - # if self.has_image and len(scheduled_image_pixel_values): - # self.pixel_values = [ - # d["pixel_values"].to(device) for d in scheduled_image_pixel_values - # ] - - # if "pixel_attention_mask" in scheduled_image_pixel_values[0]: - # self.pixel_attention_mask = [ - # d["pixel_attention_mask"].to(device) - # for d in scheduled_image_pixel_values - # ] - - # if "image_sizes" in scheduled_image_pixel_values[0]: - # self.image_sizes = [ - # d["image_sizes"].to(device) for d in scheduled_image_pixel_values - # ] - - # if "image_grid_thw" in scheduled_image_pixel_values[0]: - # self.image_grid_thw = [ - # d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values - # ] if not self.has_image: self.pixel_values = None self.pixel_attention_mask = None @@ -637,12 +762,21 @@ class VlmCausalLMBatch(FlashCausalLMBatch): if is_embed is not None: is_embed = is_embed[start_idx:end_idx] + from loguru import logger + + logger.info( + f"image_id {image_id} start_idx {start_idx} end_idx {end_idx}, length {length}" + ) + mm_embeds_item = gather_image_embeds( encoder_output[start_idx:end_idx], is_embed=is_embed, ) - mm_embeds.append(mm_embeds_item) + if mm_embeds_item is not None: + mm_embeds.append(mm_embeds_item) + if len(mm_embeds) == 0: + return None return torch.cat(mm_embeds, dim=0).to(device) def free_encoder_cache(self): @@ -662,6 +796,7 @@ class VlmCausalLM(FlashCausalLM): batch_class=VlmCausalLMBatch, revision, trust_remote_code: bool, + support_chunking: bool = True, **kwargs, ): if PREFIX_CACHING: @@ -679,8 +814,7 @@ class VlmCausalLM(FlashCausalLM): model_id=model_id, revision=revision, trust_remote_code=trust_remote_code, - # FIXME: VLM do not work with context chunking yet - support_chunking=False, + support_chunking=support_chunking, **kwargs, ) @@ -688,6 +822,153 @@ class VlmCausalLM(FlashCausalLM): def batch_type(self) -> Type[VlmCausalLMBatch]: return self.batch_class + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): + max_bs = max(self.cuda_graphs.keys()) if 
self.cuda_graphs else None + input_lengths = [max_s] * bs + cache_lengths = [0] * bs + if max_bs is None: + input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) + input_embeds = torch.zeros( + (bs, self.model.config.text_config.hidden_size), + device=self.device, + dtype=self.dtype, + ) + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + config = getattr(self.model, "config", None) + rope_scaling = getattr(config, "rope_scaling", None) if config else None + if ( # mrope have position_ids per section, if so repeat n times + isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope" + ): + n_sections = len(self.model.config.rope_scaling["mrope_section"]) + position_ids = position_ids.unsqueeze(1).repeat(1, n_sections) + slots = torch.arange(bs, dtype=torch.int64, device=self.device) + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s + ) + cache_lengths_tensor = torch.zeros( + bs, dtype=torch.int32, device=self.device + ) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + cache_lengths=cache_lengths, + input_lengths_tensor=input_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_s, + ) + else: + if bs > max_bs: + raise RuntimeError( + "Cuda graphs should be generated in decreasing order size to reduce VRAM usage" + ) + input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] + input_embeds = self.cuda_graphs[max_bs]["input_embeds"][:bs] + position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] + if ATTENTION == "flashinfer": + block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] + else: + block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs] + slots = self.cuda_graphs[max_bs]["slots"][:bs] + input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs] + cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs] + + if ATTENTION == "flashinfer": + from text_generation_server.layers.attention.flashinfer import ( + create_decode_state_cuda_graphs, + ) + + block_tables_ptr = torch.zeros( + bs + 1, dtype=torch.int32, device=self.device + ) + last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) + state = create_decode_state_cuda_graphs( + device=input_ids.device, + block_tables=block_tables, + block_tables_ptr=block_tables_ptr, + last_page_len=last_page_len, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + else: + state = None + + graph = torch.cuda.CUDAGraph() + self.cuda_graphs[bs] = { + "input_ids": input_ids, + "input_embeds": input_embeds, + "position_ids": position_ids, + "kv_cache": self.kv_cache, + "block_tables": block_tables, + "slots": slots, + "input_lengths": input_lengths_tensor, + "cache_lengths": cache_lengths_tensor, + "state": state, + "graph": graph, + } + + torch.cuda.synchronize() + # Run once outside to warmup + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=None, + input_lengths_tensor=input_lengths_tensor, + state=state, + cache_lengths_tensor=cache_lengths_tensor, + ): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + self.model.forward( + input_ids=input_ids, + inputs_embeds=input_embeds, + position_ids=position_ids, + 
cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + del seqlen + + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + inputs_embeds=input_embeds, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits + torch.cuda.synchronize() + def get_vision_embeds( self, pixel_values: torch.Tensor, @@ -901,6 +1182,7 @@ class VlmCausalLM(FlashCausalLM): # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["input_embeds"][: inputs_embeds.shape[0]] = inputs_embeds cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged(
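
The change repeated across flash_gemma3_modeling.py, flash_pali_gemma_modeling.py, idefics3.py, llava_next.py, qwen2_vl.py and qwen2_5_vl.py is the same refactor: the vision-tower pass moves out of forward() into get_vision_embeds(), get_input_embeds() scatters those features into the token embeddings at the image-placeholder positions, and forward() now consumes a precomputed inputs_embeds tensor. A minimal sketch of that calling pattern, assuming only the two methods added in this patch (the helper name and keyword plumbing below are illustrative, not part of the repo):

import torch


def build_inputs_embeds(model, input_ids, pixel_values=None, **vision_kwargs):
    # Encode scheduled images once; the models flatten the result to
    # (num_image_tokens, hidden_size) so it can be scattered into text embeddings.
    vision_embeds = None
    if pixel_values is not None and len(pixel_values) > 0:
        vision_embeds = model.get_vision_embeds(
            pixel_values=pixel_values, **vision_kwargs
        )

    # Embed the token ids and overwrite the image-placeholder positions with the
    # vision features; forward() then runs on inputs_embeds only.
    return model.get_input_embeds(input_ids, vision_embeds=vision_embeds)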
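
The index changes in image_text_replacement() (image_input[image_id]["rows"][0][0] instead of image_input["rows"][0][image_id], and likewise for image_sizes and image_grid_thw) follow from batch_tokenized_inputs() now running the image processor once per image and collecting the outputs in a list, instead of one batched call per request. A toy illustration of the new access pattern (the dict layout mirrors the lookups in the diff; the values are made up):

# One processor output per image, appended in request order.
image_inputs = [
    {"rows": [[4]], "cols": [[4]]},  # image 0 split into a 4x4 grid of slices
    {"rows": [[2]], "cols": [[3]]},  # image 1 split into a 2x3 grid of slices
]

for image_id, image_input in enumerate(image_inputs):
    n_rows = image_input["rows"][0][0]  # pre-patch: image_input["rows"][0][image_id]
    n_cols = image_input["cols"][0][0]
    print(image_id, n_rows, n_cols)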
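
get_image_positions() locates every image's placeholder run in the tokenized prompt: it gathers the indices of the image start token, uses torch.searchsorted to find the first start token at or after the previous image's end, re-tokenizes that image's replacement text to get its length, and keeps a boolean is_embed mask separating real placeholder tokens from structural ones (fake/start tokens, row separators). A condensed sketch of the same lookup with made-up token ids:

import torch

# Hypothetical ids: 32000 is the image placeholder, 32001 the image start marker.
IMAGE_TOKEN, IMAGE_START = 32000, 32001


def locate_images(input_ids, per_image_tokens):
    """per_image_tokens: one token-id list per image replacement string."""
    ids = torch.as_tensor(input_ids, dtype=torch.int64)
    start_positions = torch.where(ids == IMAGE_START)[0]

    positions, last_end = [], 0
    for tokens in per_image_tokens:
        # First image start token at or after the end of the previous image.
        idx = int(start_positions[torch.searchsorted(start_positions, last_end)])
        is_embed = torch.tensor(tokens) == IMAGE_TOKEN
        positions.append((idx, len(tokens), is_embed))
        last_end = idx + len(tokens)
    return positions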
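
With chunked prefill enabled for VLMs, scheduled images are encoded once and kept in an encoder cache keyed by request and image id; get_mm_embeds() then slices each cached output to the part of the image that falls inside the current prefill chunk, using the is_embed mask so non-placeholder tokens keep their text embeddings. A simplified sketch of that gather step (ImagePos and the helper below are illustrative, not the repo's gather_image_embeds):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ImagePos:
    offset: int                       # index of the first image token in input_ids
    length: int                       # number of prompt tokens reserved for the image
    is_embed: Optional[torch.Tensor]  # bool mask; None means every token is a placeholder


def gather_scheduled_embeds(encoder_output, pos, start, end):
    # Only [start, end) of this image's tokens fall inside the current chunk.
    chunk = encoder_output[start:end]
    if pos.is_embed is None:
        return chunk
    # Keep only the placeholder positions; structural tokens keep their text embeddings.
    return chunk[pos.is_embed[start:end]]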
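
cuda_graph_warmup() is overridden so each captured decode graph owns a static input_embeds buffer alongside input_ids; at decode time the batch's embeddings are copied into that buffer (cuda_graph["input_embeds"][: inputs_embeds.shape[0]] = inputs_embeds) before replay. A stripped-down sketch of that capture-and-replay pattern outside the TGI classes (the tiny linear model and sizes are placeholders; the patch additionally passes pool=MEM_POOL so all graphs share one memory pool):

import torch

hidden, max_bs = 64, 8
model = torch.nn.Linear(hidden, hidden).cuda()

# Static buffer: the graph records its address, so later inputs must be copied
# into it rather than passed as fresh tensors.
static_embeds = torch.zeros(max_bs, hidden, device="cuda")

# One eager warmup run, then capture a single decode step on the static buffer.
model(static_embeds)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_logits = model(static_embeds)


def replay_decode(batch_embeds):
    bs = batch_embeds.shape[0]
    static_embeds[:bs].copy_(batch_embeds)  # fill the captured buffer in place
    graph.replay()
    return static_logits[:bs]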