Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

commit 534a16d50c ("fix test")
parent 60b8cb0e46

@@ -798,6 +798,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -811,7 +812,6 @@ class Gemma3ForConditionalGeneration(nn.Module):
         # Unused here
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         if cu_seqlen_prefill is not None:
             max_s += 1

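Every modeling hunk below repeats this same move: inputs_embeds leaves the trailing block of Optional keyword arguments and becomes the first required parameter after self, so callers must build the (text plus image) embeddings before invoking forward. A minimal runnable sketch of the pattern; the toy class and the get_inputs_embeds helper are illustrative stand-ins, not code from this commit:

from typing import Optional, Tuple

import torch
import torch.nn as nn


class ToyConditionalGeneration(nn.Module):
    """Hypothetical stand-in for the multimodal classes in this diff."""

    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab)

    def get_inputs_embeds(self, input_ids: torch.Tensor) -> torch.Tensor:
        # In the real models this is where image features would be
        # merged into the token embeddings.
        return self.embed(input_ids)

    def forward(
        self,
        inputs_embeds: torch.Tensor,  # required and first, as in the diff
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = inputs_embeds  # stand-in for the decoder stack
        return self.lm_head(hidden_states), None


model = ToyConditionalGeneration()
input_ids = torch.tensor([1, 2, 3])
inputs_embeds = model.get_inputs_embeds(input_ids)
logits, _ = model.forward(inputs_embeds, torch.arange(3), None)
assert logits.shape == (3, 16)

With the "= None" default gone, a call site that fails to supply embeddings now raises a TypeError at the call instead of threading None into the decoder.
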
@@ -97,6 +97,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -109,7 +110,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         # Unused here
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # TODO This is odd but apparently pali gemma position ids start at 1.
         if cu_seqlen_prefill is not None:

@@ -819,6 +819,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -831,7 +832,6 @@ class Idefics2ForConditionalGeneration(nn.Module):
         # Unused here
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,

@@ -561,6 +561,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -574,7 +575,6 @@ class Idefics3ForConditionalGeneration(nn.Module):
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_indices=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,

@@ -1084,7 +1084,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
     # def get_input_embeddings(self):
     #     return self.embed_tokens
 
-    # def set_inputs_embeds(self, value):
+    # def set_input_embeddings(self, value):
     #     self.embed_tokens = value
 
     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask

@@ -270,6 +270,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -282,7 +283,6 @@ class LlavaNextForConditionalGeneration(nn.Module):
         # Unused for this model
         attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,

@@ -947,6 +947,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -960,7 +961,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_indices=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,

@@ -525,6 +525,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     def forward(
         self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -537,7 +538,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
         image_indices=None,
         attention_mask=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,

@@ -919,6 +919,9 @@ class VlmCausalLM(FlashCausalLM):
         )
 
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
+        batch.image_sizes = None
+        batch.image_grid_thw = None
 
     def set_inputs_embeds(self, batch):
         if batch.has_image_inputs:

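A hedged reading of the hunk above: once the image features have been fused into the embeddings, every per-image field on the batch is cleared, not just pixel_values, so later steps cannot re-run vision preprocessing on stale tensors. A self-contained sketch; this Batch dataclass is an assumed stand-in for the real batch class, with the attribute names taken from the diff:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Batch:
    # Stand-in for the real VLM batch; only the image fields are modeled.
    pixel_values: Optional[Any] = None
    pixel_attention_mask: Optional[Any] = None
    image_sizes: Optional[Any] = None
    image_grid_thw: Optional[Any] = None


def clear_image_inputs(batch: Batch) -> None:
    # Mirrors the added lines: drop all image inputs after the
    # embeddings have been computed once.
    batch.pixel_values = None
    batch.pixel_attention_mask = None
    batch.image_sizes = None
    batch.image_grid_thw = None
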
@@ -1005,18 +1008,14 @@ class VlmCausalLM(FlashCausalLM):
         attention_mask = None
         attention_mask_forward = None
         if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
-            # Get the mask, needed for flashinfer.
-            has_image = (input_ids == self.model.config.image_token_index).any()
-
-            if has_image:
-                attention_mask = self.model.get_attention_mask(
-                    input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
-                )
-                min_dtype = torch.finfo(self.dtype).min
-                attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
-                    input_ids.device
-                )
-                attention_mask = attention_mask.reshape(-1)
+            attention_mask = self.model.get_attention_mask(
+                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
+            )
+            min_dtype = torch.finfo(self.dtype).min
+            attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
+                input_ids.device
+            )
+            attention_mask = attention_mask.reshape(-1)
 
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]

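This final hunk drops the old has_image gate: on gemma3 prefill the boolean attention mask is now always built, converted once into the additive float mask the dense forward expects (0 where attention is allowed, the dtype minimum where it is blocked), and kept in flattened boolean form, which the removed comment attributed to flashinfer. A small runnable sketch of that conversion, using an assumed 4x4 causal toy mask in place of the model's real one:

import torch

dtype = torch.float16
# Toy 4 x 4 boolean mask: True marks positions that may be attended.
bool_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))

# As in the diff: 0 keeps a position, the dtype's minimum blocks it,
# so softmax drives blocked positions to ~0 probability.
min_dtype = torch.finfo(dtype).min
attention_mask_forward = torch.where(bool_mask, 0, min_dtype)

# Flattened boolean form, as produced by the reshape(-1) above.
flat_mask = bool_mask.reshape(-1)

assert attention_mask_forward[0, 1] == min_dtype  # future token blocked
assert attention_mask_forward[1, 0] == 0          # past token visible
assert flat_mask.shape == (16,)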