diff --git a/router/src/config.rs b/router/src/config.rs
index eb16e88b..9c31e6e8 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -162,9 +162,7 @@ pub struct Qwen2Vl {
 impl Qwen2Vl {
     pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
         let num_pixels = height * width;
-        let num_image_tokens = num_pixels / self.vision_config.patch_size.pow(2);
-        let start_and_end_tokens = 2;
-        num_image_tokens + start_and_end_tokens
+        num_pixels / self.vision_config.patch_size.pow(2)
     }
 }
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index 0024f2bb..b1f89eff 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -80,6 +80,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         pixel_attention_mask: Optional[torch.BoolTensor] = None,
         image_sizes: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         # TODO This is odd but apparently pali gemma position ids start at 1.
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py
index a829c374..923123d6 100644
--- a/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -750,6 +750,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
         # Unused here
         image_sizes: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         if pixel_values is not None:
diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index 32e9d334..df7366ea 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -180,6 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         pixel_attention_mask=None,
         image_sizes: Optional[torch.LongTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         if pixel_values is not None and len(pixel_values) > 0:
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index fc813b30..df2c2a2c 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -360,7 +360,7 @@ class VlmCausalLM(FlashCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
-        if self.model.get_position_ids:
+        if hasattr(self.model, "get_position_ids"):
             if position_ids.shape[0] != 1:
                 position_ids = self.model.get_position_ids(
                     input_ids.unsqueeze(0), batch.image_grid_thw
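
Note on the vlm_causal_lm.py guard: the old truthiness check `if self.model.get_position_ids:` evaluates the attribute itself, which raises AttributeError on any model class that does not define `get_position_ids` (the signature-only additions above suggest only some models provide it). The `hasattr` form degrades to a plain boolean test instead. A minimal, self-contained sketch of the difference, using hypothetical stand-in classes rather than TGI's real model classes:

```python
# Hypothetical stand-ins for TGI model classes, for illustration only.
class CustomPositionIdsModel:
    def get_position_ids(self, input_ids, image_grid_thw):
        # A real model would compute multimodal rotary position ids here;
        # this sketch just echoes the input.
        return input_ids


class PlainModel:
    pass  # defines no get_position_ids


for model in (CustomPositionIdsModel(), PlainModel()):
    # Old check: `if model.get_position_ids:` raises AttributeError for
    # PlainModel before the truthiness test even runs.
    # New check: hasattr() returns False instead, so the branch is skipped.
    if hasattr(model, "get_position_ids"):
        print(type(model).__name__, "uses custom position ids")
    else:
        print(type(model).__name__, "uses default position ids")
```

Feature-detecting an optional hook with `hasattr` is the idiomatic pattern here; it lets models that do not implement the method fall through to the default position-id path without try/except noise.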