From 60b8cb0e46838ca8681a3ff1c4e438744ba7b86b Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Fri, 25 Apr 2025 11:54:26 +0000
Subject: [PATCH] fix config image_token_id error

---
 .../models/custom_modeling/flash_pali_gemma_modeling.py | 6 +-----
 server/text_generation_server/models/vlm_causal_lm.py   | 5 ++---
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index 6233f186..bfa1ff74 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -97,7 +97,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):

     def forward(
         self,
-        input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -107,12 +106,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
         # Unused here
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
-        image_sizes: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # TODO This is odd but apparently pali gemma position ids start at 1.
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index def9bea8..764e02f6 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -389,9 +389,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         max_length = 0
         vocab = tokenizer.get_vocab()

-        config.image_token_index = getattr(
-            config, "image_token_index", config.image_token_id
-        )
+        if not hasattr(config, "image_token_index"):
+            config.image_token_index = config.image_token_id

         batch_tokenized_inputs: List[List[int]] = []
         batch_image_inputs: List[Optional[List[dict]]] = []
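
Note (illustration only, not part of the patch): the old getattr form appears to fail because Python evaluates the default argument eagerly, so config.image_token_id is accessed even when config already defines image_token_index. A minimal sketch of the difference, using a hypothetical SimpleNamespace stand-in for a PaliGemma-style config:

    # Sketch only: SimpleNamespace stands in for a config that defines
    # image_token_index but not image_token_id; 257152 is just an example value.
    from types import SimpleNamespace

    config = SimpleNamespace(image_token_index=257152)

    # Old pattern: the fallback expression is evaluated before getattr runs,
    # so the missing attribute is touched unconditionally -> AttributeError.
    try:
        config.image_token_index = getattr(
            config, "image_token_index", config.image_token_id
        )
    except AttributeError as err:
        print("old pattern raises:", err)

    # New pattern: image_token_id is only read when image_token_index is absent.
    if not hasattr(config, "image_token_index"):
        config.image_token_index = config.image_token_id
    print("image_token_index:", config.image_token_index)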