Fix config image_token_id error: only read config.image_token_id when image_token_index is missing

This commit is contained in:
Mohit Sharma 2025-04-25 11:54:26 +00:00
parent 419ecd0167
commit 60b8cb0e46
2 changed files with 3 additions and 8 deletions

View File

@@ -97,7 +97,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -107,12 +106,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused here # Unused here
pixel_attention_mask: Optional[torch.BoolTensor] = None, attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# TODO This is odd but apparently pali gemma position ids start at 1. # TODO This is odd but apparently pali gemma position ids start at 1.

View File

@@ -389,9 +389,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
max_length = 0 max_length = 0
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
config.image_token_index = getattr( if not hasattr(config, "image_token_index"):
config, "image_token_index", config.image_token_id config.image_token_index = config.image_token_id
)
batch_tokenized_inputs: List[List[int]] = [] batch_tokenized_inputs: List[List[int]] = []
batch_image_inputs: List[Optional[List[dict]]] = [] batch_image_inputs: List[Optional[List[dict]]] = []