mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
fix test
This commit is contained in:
parent
60b8cb0e46
commit
534a16d50c
@ -798,6 +798,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -811,7 +812,6 @@ class Gemma3ForConditionalGeneration(nn.Module):
|
||||
# Unused here
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if cu_seqlen_prefill is not None:
|
||||
max_s += 1
|
||||
|
@ -97,6 +97,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -109,7 +110,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
# Unused here
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# TODO This is odd but apparently pali gemma position ids start at 1.
|
||||
if cu_seqlen_prefill is not None:
|
||||
|
@ -819,6 +819,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -831,7 +832,6 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
||||
# Unused here
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -561,6 +561,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -574,7 +575,6 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_indices=None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -1084,7 +1084,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
# def get_input_embeddings(self):
|
||||
# return self.embed_tokens
|
||||
|
||||
# def set_inputs_embeds(self, value):
|
||||
# def set_input_embeddings(self, value):
|
||||
# self.embed_tokens = value
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||
|
@ -270,6 +270,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -282,7 +283,6 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
# Unused for this model
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -947,6 +947,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -960,7 +961,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_indices=None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states = self.text_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -525,6 +525,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -537,7 +538,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_indices=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states = self.text_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -919,6 +919,9 @@ class VlmCausalLM(FlashCausalLM):
|
||||
)
|
||||
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
batch.image_grid_thw = None
|
||||
|
||||
def set_inputs_embeds(self, batch):
|
||||
if batch.has_image_inputs:
|
||||
@ -1005,10 +1008,6 @@ class VlmCausalLM(FlashCausalLM):
|
||||
attention_mask = None
|
||||
attention_mask_forward = None
|
||||
if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
|
||||
# Get the mask, needed for flashinfer.
|
||||
has_image = (input_ids == self.model.config.image_token_index).any()
|
||||
|
||||
if has_image:
|
||||
attention_mask = self.model.get_attention_mask(
|
||||
input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user