mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: check get_position_ids availability with hasattr and add required args to vision-model forward signatures
This commit is contained in:
parent
77c81a29cb
commit
4f90db47be
@ -162,9 +162,7 @@ pub struct Qwen2Vl {
|
||||
impl Qwen2Vl {
|
||||
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
||||
let num_pixels = height * width;
|
||||
let num_image_tokens = num_pixels / self.vision_config.patch_size.pow(2);
|
||||
let start_and_end_tokens = 2;
|
||||
num_image_tokens + start_and_end_tokens
|
||||
num_pixels / self.vision_config.patch_size.pow(2)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -80,6 +80,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
# TODO This is odd but apparently pali gemma position ids start at 1.
|
||||
|
@ -750,6 +750,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
||||
# Unused here
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
|
@ -180,6 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
|
@ -360,7 +360,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||
max_s = batch.max_current_length
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if self.model.get_position_ids:
|
||||
if hasattr(self.model, "get_position_ids"):
|
||||
if position_ids.shape[0] != 1:
|
||||
position_ids = self.model.get_position_ids(
|
||||
input_ids.unsqueeze(0), batch.image_grid_thw
|
||||
|
Loading…
Reference in New Issue
Block a user