mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
fix config image_token_id error
This commit is contained in:
parent
419ecd0167
commit
60b8cb0e46
@ -97,7 +97,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -107,12 +106,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
# Unused here
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# TODO This is odd but apparently pali gemma position ids start at 1.
|
||||
|
@ -389,9 +389,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
max_length = 0
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
config.image_token_index = getattr(
|
||||
config, "image_token_index", config.image_token_id
|
||||
)
|
||||
if not hasattr(config, "image_token_index"):
|
||||
config.image_token_index = config.image_token_id
|
||||
|
||||
batch_tokenized_inputs: List[List[int]] = []
|
||||
batch_image_inputs: List[Optional[List[dict]]] = []
|
||||
|
Loading…
Reference in New Issue
Block a user