fix: adjust get_position_ids if not available and add required args to signatures

This commit is contained in:
David Holtz 2024-10-29 15:26:41 +00:00
parent 77c81a29cb
commit 4f90db47be
5 changed files with 5 additions and 4 deletions

View File

@ -162,9 +162,7 @@ pub struct Qwen2Vl {
impl Qwen2Vl {
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
let num_pixels = height * width;
let num_image_tokens = num_pixels / self.vision_config.patch_size.pow(2);
let start_and_end_tokens = 2;
num_image_tokens + start_and_end_tokens
num_pixels / self.vision_config.patch_size.pow(2)
}
}

View File

@ -80,6 +80,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.

View File

@ -750,6 +750,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:

View File

@ -180,6 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:

View File

@ -360,7 +360,7 @@ class VlmCausalLM(FlashCausalLM):
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if self.model.get_position_ids:
if hasattr(self.model, "get_position_ids"):
if position_ids.shape[0] != 1:
position_ids = self.model.get_position_ids(
input_ids.unsqueeze(0), batch.image_grid_thw