diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
index a77836af..b0047f1e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
@@ -798,6 +798,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -811,7 +812,6 @@ class Gemma3ForConditionalGeneration(nn.Module):
         # Unused here
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         if cu_seqlen_prefill is not None:
             max_s += 1
diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index bfa1ff74..ef222c76 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -97,6 +97,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -109,7 +110,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         # Unused here
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # TODO This is odd but apparently pali gemma position ids start at 1.
         if cu_seqlen_prefill is not None:
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py
index 01b7e50a..5c0d2fcc 100644
--- a/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -819,6 +819,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -831,7 +832,6 @@ class Idefics2ForConditionalGeneration(nn.Module):
         # Unused here
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py
index d59c1c20..6d303c2c 100644
--- a/server/text_generation_server/models/custom_modeling/idefics3.py
+++ b/server/text_generation_server/models/custom_modeling/idefics3.py
@@ -561,6 +561,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -574,7 +575,6 @@ class Idefics3ForConditionalGeneration(nn.Module):
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_indices=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index 2da36ecc..9fc9bca6 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -1084,7 +1084,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
     # def get_input_embeddings(self):
     #     return self.embed_tokens
 
-    # def set_inputs_embeds(self, value):
+    # def set_input_embeddings(self, value):
     #     self.embed_tokens = value
 
     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index 5183a742..56a9565b 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -270,6 +270,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -282,7 +283,6 @@ class LlavaNextForConditionalGeneration(nn.Module):
         # Unused for this model
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
index ffe5357d..c1af3c28 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -947,6 +947,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -960,7 +961,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_indices=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 80f483e3..05d13786 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -525,6 +525,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -537,7 +538,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
         image_indices=None,
         attention_mask=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 764e02f6..b7fd88a6 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -919,6 +919,9 @@ class VlmCausalLM(FlashCausalLM):
 
             )
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
+        batch.image_sizes = None
+        batch.image_grid_thw = None
 
     def set_inputs_embeds(self, batch):
         if batch.has_image_inputs:
@@ -1005,18 +1008,14 @@
         attention_mask = None
         attention_mask_forward = None
         if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
-            # Get the mask, needed for flashinfer.
-            has_image = (input_ids == self.model.config.image_token_index).any()
-
-            if has_image:
-                attention_mask = self.model.get_attention_mask(
-                    input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
-                )
-                min_dtype = torch.finfo(self.dtype).min
-                attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
-                    input_ids.device
-                )
-                attention_mask = attention_mask.reshape(-1)
+            attention_mask = self.model.get_attention_mask(
+                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
+            )
+            min_dtype = torch.finfo(self.dtype).min
+            attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
+                input_ids.device
+            )
+            attention_mask = attention_mask.reshape(-1)
 
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
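For context, a minimal sketch of the calling convention this change implies: `inputs_embeds` becomes the first required argument of the multimodal `forward` methods, so the merged text/image embeddings have to be built before the model call (as `set_inputs_embeds` does on the batch). The helper below is illustrative only; `embed_tokens`, `vision_embeds`, and `image_token_id` are assumed names, not the exact TGI call sites touched by this diff.

```python
# Illustrative sketch only -- names (embed_tokens, vision_embeds, image_token_id)
# are assumptions, not the exact helpers used in text-generation-inference.
import torch


def build_inputs_embeds(
    input_ids: torch.Tensor,           # [num_tokens]
    embed_tokens: torch.nn.Embedding,  # text embedding table
    vision_embeds: torch.Tensor,       # [num_image_tokens, hidden_size], may be empty
    image_token_id: int,
) -> torch.Tensor:
    """Merge text and (optional) image embeddings into the tensor passed as the
    new required first argument of the VLM forward methods."""
    inputs_embeds = embed_tokens(input_ids)
    if vision_embeds.numel() > 0:
        # Replace placeholder image-token positions with the vision embeddings.
        mask = input_ids == image_token_id
        inputs_embeds[mask] = vision_embeds.to(inputs_embeds.dtype)
    return inputs_embeds


# With the embeddings prepared up front, the forward call always receives
# inputs_embeds first, e.g.:
# logits, speculative_logits = model(
#     inputs_embeds=inputs_embeds,
#     position_ids=position_ids,
#     cu_seqlen_prefill=cu_seqlen_prefill,
#     ...
# )
```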