Mohit Sharma 2025-04-28 09:40:23 +00:00
parent 60b8cb0e46
commit 534a16d50c
9 changed files with 19 additions and 20 deletions

@@ -798,6 +798,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -811,7 +812,6 @@ class Gemma3ForConditionalGeneration(nn.Module):
         # Unused here
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         if cu_seqlen_prefill is not None:
             max_s += 1

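This hunk pair shows the pattern repeated in every model file below: `inputs_embeds` moves from an optional trailing keyword argument to the first required positional parameter of `forward`. A minimal runnable sketch of that refactor on a toy function (the names and tensor shapes are illustrative, not the repository's):

    import torch
    from typing import Optional

    # Before: inputs_embeds is an optional trailing keyword argument.
    def forward_old(
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert inputs_embeds is not None, "callers must still supply embeddings"
        return inputs_embeds + position_ids.unsqueeze(-1)

    # After: inputs_embeds is the first required positional parameter.
    def forward_new(
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        return inputs_embeds + position_ids.unsqueeze(-1)

    embeds = torch.randn(4, 8)
    position_ids = torch.arange(4)
    assert torch.equal(
        forward_old(position_ids, inputs_embeds=embeds),
        forward_new(embeds, position_ids),
    )

Making the argument required and positional means a caller that forgets to compute the embeddings fails at the call site instead of deep inside the model.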
@@ -97,6 +97,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -109,7 +110,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         # Unused here
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # TODO This is odd but apparently pali gemma position ids start at 1.
         if cu_seqlen_prefill is not None:

@@ -819,6 +819,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -831,7 +832,6 @@ class Idefics2ForConditionalGeneration(nn.Module):
         # Unused here
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,

@@ -561,6 +561,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -574,7 +575,6 @@ class Idefics3ForConditionalGeneration(nn.Module):
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_indices=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,

@@ -1084,7 +1084,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
     # def get_input_embeddings(self):
     #     return self.embed_tokens
-    # def set_inputs_embeds(self, value):
+    # def set_input_embeddings(self, value):
     #     self.embed_tokens = value
     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask

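The hunk above only renames a method inside a commented-out block, restoring the standard transformers accessor name (`set_input_embeddings` rather than `set_inputs_embeds`). For reference, a minimal sketch of that conventional accessor pair, assuming a module that stores its token embeddings in an `embed_tokens` attribute:

    import torch.nn as nn

    class ToyTextModel(nn.Module):
        # Conventional transformers-style embedding accessors (sketch).
        def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

        def get_input_embeddings(self) -> nn.Embedding:
            return self.embed_tokens

        def set_input_embeddings(self, value: nn.Embedding) -> None:
            self.embed_tokens = value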
@@ -270,6 +270,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -282,7 +283,6 @@ class LlavaNextForConditionalGeneration(nn.Module):
         # Unused for this model
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,

@@ -947,6 +947,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -960,7 +961,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_indices=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,

@@ -525,6 +525,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -537,7 +538,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
         image_indices=None,
         attention_mask=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,

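All of the forward signatures above now take precomputed `inputs_embeds` rather than raw `input_ids`, so the embedding lookup and the vision-tower merge happen before the model call. A minimal sketch of how such embeddings are typically assembled for a VLM (the token id, shapes, and tensors are placeholders, not the repository's API):

    import torch

    IMAGE_TOKEN_ID = 32000  # placeholder id marking image positions
    input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 42])
    text_embeds = torch.randn(4, 16)   # one row per token, from the embedding table
    image_embeds = torch.randn(2, 16)  # one row per image token, from the vision tower

    # Scatter vision features into the text embeddings at image-token positions.
    inputs_embeds = text_embeds.clone()
    inputs_embeds[input_ids == IMAGE_TOKEN_ID] = image_embeds

    # inputs_embeds is then passed as the first positional argument to forward.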
@@ -919,6 +919,9 @@ class VlmCausalLM(FlashCausalLM):
             )
             batch.pixel_values = None
             batch.pixel_attention_mask = None
+            batch.image_sizes = None
+            batch.image_grid_thw = None

     def set_inputs_embeds(self, batch):
         if batch.has_image_inputs:
@@ -1005,10 +1008,6 @@ class VlmCausalLM(FlashCausalLM):
         attention_mask = None
         attention_mask_forward = None
         if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
             # Get the mask, needed for flashinfer.
             has_image = (input_ids == self.model.config.image_token_index).any()
             if has_image:
                 attention_mask = self.model.get_attention_mask(
                     input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
                 )
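For context, the gemma3 branch above builds a boolean attention mask only when the prefill batch actually contains image tokens. A self-contained sketch of that check (the token id is a placeholder, not gemma3's real `config.image_token_index`):

    import torch

    IMAGE_TOKEN_INDEX = 99  # placeholder, stands in for config.image_token_index
    input_ids = torch.tensor([3, 7, IMAGE_TOKEN_INDEX, 12])

    # Only pay for mask construction when an image token is present in the batch.
    has_image = bool((input_ids == IMAGE_TOKEN_INDEX).any())
    print(has_image)  # True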