From 3bb514ddd827a74cbf11b1c8c03876bfeab28347 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Thu, 24 Apr 2025 13:33:22 +0000
Subject: [PATCH] remove kwargs and redundant args

---
 .../custom_modeling/flash_gemma3_modeling.py  | 15 ++-----
 .../flash_pali_gemma_modeling.py              |  5 ++-
 .../models/custom_modeling/idefics2.py        | 19 ++-------
 .../models/custom_modeling/idefics3.py        | 21 ++--------
 .../models/custom_modeling/llava_next.py      | 17 ++------
 .../models/custom_modeling/qwen2_5_vl.py      | 12 ++----
 .../models/custom_modeling/qwen2_vl.py        | 12 ++----
 .../models/transformers_flash_vlm.py          |  9 ++++-
 .../models/vlm_causal_lm.py                   | 39 +++++++------------
 9 files changed, 46 insertions(+), 103 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
index d5ba9a8a..a77836af 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
@@ -766,7 +766,9 @@ class Gemma3ForConditionalGeneration(nn.Module):
     def get_vision_embeds(
         self,
         pixel_values: torch.FloatTensor,
-        **kwargs,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
     ):
         pixel_values = pixel_values.to(dtype=self.dtype)
         image_outputs = self.vision_model(pixel_values)
@@ -781,7 +783,6 @@ class Gemma3ForConditionalGeneration(nn.Module):
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
-        **kwargs,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
 
@@ -797,7 +798,6 @@ class Gemma3ForConditionalGeneration(nn.Module):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -810,22 +810,13 @@ class Gemma3ForConditionalGeneration(nn.Module):
         pixel_values: torch.FloatTensor = None,
         # Unused here
         attention_mask: Optional[torch.BoolTensor] = None,
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
-        image_sizes: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         if cu_seqlen_prefill is not None:
             max_s += 1
             position_ids += 1
 
-        if pixel_values:
-            attention_mask = self.get_attention_mask(
-                input_ids,
-                cu_seqlen_prefill,
-                inputs_embeds.dtype,
-            )
         # Use flash attention for text-only input
         # else:
         #     if cu_seqlen_prefill is not None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index d1117e39..6233f186 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -67,7 +67,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
     def get_vision_embeds(
         self,
         pixel_values: torch.FloatTensor,
-        **kwargs,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
     ):
         pixel_values = pixel_values.to(dtype=self.dtype)
         image_outputs = self.vision_tower(pixel_values)
@@ -84,7 +86,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
-        **kwargs,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
 
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py
index bb4c6ca3..01b7e50a 100644
--- a/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -736,8 +736,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
     def get_vision_embeds(
         self,
         pixel_values: torch.FloatTensor,
-        pixel_attention_mask: torch.BoolTensor,
-        **kwargs,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
     ):
         assert pixel_values is not None
         batch_size, num_images, num_channels, height, width = pixel_values.shape
@@ -805,16 +806,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
-        pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask: torch.BoolTensor = None,
-        **kwargs,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
-        if vision_embeds is None and pixel_values is not None:
-            vision_embeds = self.get_vision_embeds(
-                pixel_values=pixel_values,
-                pixel_attention_mask=pixel_attention_mask,
-            )
 
         if vision_embeds is not None:
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
@@ -826,7 +819,6 @@ class Idefics2ForConditionalGeneration(nn.Module):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -836,12 +828,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
         # Unused here
-        image_sizes: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model.model(
diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py
index d4165502..d59c1c20 100644
--- a/server/text_generation_server/models/custom_modeling/idefics3.py
+++ b/server/text_generation_server/models/custom_modeling/idefics3.py
@@ -479,8 +479,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
     def get_vision_embeds(
         self,
         pixel_values: torch.FloatTensor,
-        pixel_attention_mask: torch.BoolTensor,
-        **kwargs,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
     ):
         batch_size, num_images, num_channels, height, width = pixel_values.shape
         all_states = []
@@ -547,16 +548,8 @@ class Idefics3ForConditionalGeneration(nn.Module):
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
-        pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask: torch.BoolTensor = None,
-        **kwargs,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
-        if vision_embeds is None and pixel_values is not None:
-            vision_embeds = self.get_vision_embeds(
-                pixel_values=pixel_values,
-                pixel_attention_mask=pixel_attention_mask,
-            )
 
         if vision_embeds is not None:
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
@@ -568,7 +561,6 @@ class Idefics3ForConditionalGeneration(nn.Module):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -578,14 +570,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
         # Unused here
-        image_sizes: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
-        video_grid_thw: Optional[torch.LongTensor] = None,
-        cross_attention_states: Optional[torch.Tensor] = None,
         image_indices=None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ):
diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index bb6da022..5183a742 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -166,8 +166,9 @@ class LlavaNextForConditionalGeneration(nn.Module):
     def get_vision_embeds(
         self,
         pixel_values: torch.FloatTensor,
-        image_sizes: Optional[torch.LongTensor] = None,
-        **kwargs,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
     ):
         # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
         # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
@@ -256,14 +257,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
         vision_embeds: torch.Tensor = None,
         pixel_values: torch.FloatTensor = None,
         image_sizes: Optional[torch.LongTensor] = None,
-        **kwargs,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
-        if vision_embeds is None and pixel_values is not None:
-            vision_embeds = self.get_vision_embeds(
-                pixel_values=pixel_values,
-                image_sizes=image_sizes,
-            )
 
         if vision_embeds is not None:
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
@@ -275,7 +270,6 @@ class LlavaNextForConditionalGeneration(nn.Module):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -285,12 +279,9 @@ class LlavaNextForConditionalGeneration(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
         # Unused for this model
-        pixel_attention_mask=None,
-        image_sizes: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model.model(
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
index addb9032..ffe5357d 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -925,8 +925,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
     def get_vision_embeds(
         self,
         pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
-        **kwargs,
     ):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
         return image_embeds
@@ -935,7 +936,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
-        **kwargs,
     ):
         inputs_embeds = self.embed_tokens(input_ids)
 
@@ -947,7 +947,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -957,14 +956,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
-        pixel_values: torch.FloatTensor = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
         # Unused in this model
-        video_grid_thw: Optional[torch.LongTensor] = None,
-        pixel_attention_mask=None,
-        image_sizes: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        cross_attention_states: Optional[torch.Tensor] = None,
         image_indices=None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ):
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 0ca41c1d..80f483e3 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -503,8 +503,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_vision_embeds(
         self,
         pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
-        **kwargs,
     ):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
         return image_embeds
@@ -513,7 +514,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
-        **kwargs,
     ):
         inputs_embeds = self.embed_tokens(input_ids)
 
@@ -525,7 +525,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -535,14 +534,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
-        pixel_values: torch.FloatTensor = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
-        video_grid_thw: Optional[torch.LongTensor] = None,
-        pixel_attention_mask=None,
-        image_sizes: Optional[torch.LongTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        cross_attention_states: Optional[torch.Tensor] = None,
         image_indices=None,
+        attention_mask=None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model(
diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py
index 518e5972..98644836 100644
--- a/server/text_generation_server/models/transformers_flash_vlm.py
+++ b/server/text_generation_server/models/transformers_flash_vlm.py
@@ -577,8 +577,13 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
         inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
         return inputs
 
-    def get_vision_embeds(self, pixel_values, **kwargs):
-        image_sizes = kwargs.get("image_sizes", None)
+    def get_vision_embeds(
+        self,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
         image_features = self.model.get_image_features(
             pixel_values=pixel_values,
             vision_feature_layer=self.model.config.vision_config.vision_feature_layer,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 872418e8..974fada4 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -721,7 +721,6 @@ class VlmCausalLM(FlashCausalLM):
         input_lengths = [max_s] * bs
         cache_lengths = [0] * bs
         if max_bs is None:
-            input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
             inputs_embeds = torch.zeros(
                 (bs, self.model.config.text_config.hidden_size),
                 device=self.device,
@@ -760,7 +759,6 @@ class VlmCausalLM(FlashCausalLM):
                 raise RuntimeError(
                     "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
                 )
-            input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
             inputs_embeds = self.cuda_graphs[max_bs]["inputs_embeds"][:bs]
             position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
         if ATTENTION == "flashinfer":
@@ -781,7 +779,7 @@ class VlmCausalLM(FlashCausalLM):
             )
             last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
             state = create_decode_state_cuda_graphs(
-                device=input_ids.device,
+                device=inputs_embeds.device,
                 block_tables=block_tables,
                 block_tables_ptr=block_tables_ptr,
                 last_page_len=last_page_len,
@@ -793,7 +791,6 @@ class VlmCausalLM(FlashCausalLM):
 
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs] = {
-            "input_ids": input_ids,
             "inputs_embeds": inputs_embeds,
             "position_ids": position_ids,
             "kv_cache": self.kv_cache,
@@ -822,7 +819,6 @@ class VlmCausalLM(FlashCausalLM):
                 max_k=max_s,
             )
             self.model.forward(
-                input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
@@ -847,7 +843,6 @@ class VlmCausalLM(FlashCausalLM):
                 max_k=max_s,
            )
             logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
@@ -1007,14 +1002,21 @@ class VlmCausalLM(FlashCausalLM):
             )
             batch.position_ids = position_ids
 
+        attention_mask = None
+        attention_mask_forward = None
         if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
             # Get the mask, needed for flashinfer.
-            attention_mask = self.model.get_attention_mask(
-                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
-            ).reshape(-1)
-            batch.pixel_values = 1
-        else:
-            attention_mask = None
+            has_image = (input_ids == self.model.config.image_token_index).any()
+
+            if has_image:
+                attention_mask = self.model.get_attention_mask(
+                    input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
+                )
+                min_dtype = torch.finfo(self.dtype).min
+                attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
+                    input_ids.device
+                )
+                attention_mask = attention_mask.reshape(-1)
 
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
@@ -1049,7 +1051,6 @@ class VlmCausalLM(FlashCausalLM):
                 max_k=batch.max_current_length,
             )
             logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
@@ -1060,27 +1061,17 @@ class VlmCausalLM(FlashCausalLM):
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
-                pixel_values=batch.pixel_values,
-                pixel_attention_mask=batch.pixel_attention_mask,
-                image_sizes=batch.image_sizes,
-                image_grid_thw=batch.image_grid_thw,
+                attention_mask=attention_mask_forward,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             if batch.pixel_values is not None:
                 batch.pixel_values = None
-            if batch.pixel_attention_mask is not None:
-                batch.pixel_attention_mask = None
-            if batch.image_sizes is not None:
-                batch.image_sizes = None
-            if batch.image_grid_thw is not None:
-                batch.image_grid_thw = None
             batch.free_encoder_cache()
             return logits, speculative_logits
 
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
-        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
         cuda_graph["inputs_embeds"][: inputs_embeds.shape[0]] = inputs_embeds
         cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
         if ATTENTION == "flashinfer":