remove kwargs and redundant args

Mohit Sharma 2025-04-24 13:33:22 +00:00
parent 36c5ec2abe
commit 3bb514ddd8
9 changed files with 46 additions and 103 deletions
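The commit's stated goal is to drop the catch-all **kwargs from the models' vision entry points and spell out the same explicit parameter list everywhere. A minimal sketch of the shared interface the hunks below converge on (the base class here is illustrative, not code from the commit):

from typing import Optional

import torch
from torch import nn


class VisionLanguageModelInterface(nn.Module):
    # Illustrative base class: the models touched by this commit implement
    # these two methods with explicit arguments instead of **kwargs.
    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Each model reads only the arguments it needs and ignores the rest.
        raise NotImplementedError

    def get_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Token embeddings, with image-token positions replaced by the
        # precomputed vision embeddings when they are provided.
        raise NotImplementedError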


@ -766,7 +766,9 @@ class Gemma3ForConditionalGeneration(nn.Module):
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
**kwargs,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_model(pixel_values)
@ -781,7 +783,6 @@ class Gemma3ForConditionalGeneration(nn.Module):
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
@ -797,7 +798,6 @@ class Gemma3ForConditionalGeneration(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -810,22 +810,13 @@ class Gemma3ForConditionalGeneration(nn.Module):
pixel_values: torch.FloatTensor = None,
# Unused here
attention_mask: Optional[torch.BoolTensor] = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1
if pixel_values:
attention_mask = self.get_attention_mask(
input_ids,
cu_seqlen_prefill,
inputs_embeds.dtype,
)
# Use flash attention for text-only input
# else:
# if cu_seqlen_prefill is not None:


@ -67,7 +67,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
**kwargs,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_tower(pixel_values)
@ -84,7 +86,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)


@ -736,8 +736,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: torch.BoolTensor,
**kwargs,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
assert pixel_values is not None
batch_size, num_images, num_channels, height, width = pixel_values.shape
@ -805,16 +806,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: torch.BoolTensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is None and pixel_values is not None:
vision_embeds = self.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
@ -826,7 +819,6 @@ class Idefics2ForConditionalGeneration(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -836,12 +828,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model.model(
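With the pixel_values fallback removed from get_inputs_embeds, the model no longer runs its vision tower on demand: the caller has to produce the vision embeddings first and pass only the result. A hedged sketch of that two-step flow (the model and batch objects and their attribute names are assumptions based on the diff, not the exact serving code):

import torch


def embed_multimodal_inputs(model, batch) -> torch.Tensor:
    # Step 1: run the vision tower once, only if the batch carries images.
    vision_embeds = None
    if getattr(batch, "pixel_values", None) is not None:
        vision_embeds = model.get_vision_embeds(
            pixel_values=batch.pixel_values,
            pixel_attention_mask=batch.pixel_attention_mask,
            image_sizes=batch.image_sizes,
            image_grid_thw=batch.image_grid_thw,
        )
    # Step 2: merge the vision embeddings into the token embeddings; forward()
    # then consumes inputs_embeds rather than pixel values.
    return model.get_inputs_embeds(
        input_ids=batch.input_ids,
        vision_embeds=vision_embeds,
    )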


@ -479,8 +479,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: torch.BoolTensor,
**kwargs,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
@ -547,16 +548,8 @@ class Idefics3ForConditionalGeneration(nn.Module):
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: torch.BoolTensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is None and pixel_values is not None:
vision_embeds = self.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
@ -568,7 +561,6 @@ class Idefics3ForConditionalGeneration(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -578,14 +570,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
inputs_embeds: Optional[torch.Tensor] = None,
):


@ -166,8 +166,9 @@ class LlavaNextForConditionalGeneration(nn.Module):
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
image_sizes: Optional[torch.LongTensor] = None,
**kwargs,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
@ -256,14 +257,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is None and pixel_values is not None:
vision_embeds = self.get_vision_embeds(
pixel_values=pixel_values,
image_sizes=image_sizes,
)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
@ -275,7 +270,6 @@ class LlavaNextForConditionalGeneration(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -285,12 +279,9 @@ class LlavaNextForConditionalGeneration(nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model.model(


@ -925,8 +925,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
@ -935,7 +936,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.embed_tokens(input_ids)
@ -947,7 +947,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -957,14 +956,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
# Unused in this model
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
inputs_embeds: Optional[torch.Tensor] = None,
):
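For the Qwen2-VL family, image_grid_thw is the only one of the explicit vision arguments that get_vision_embeds actually uses (it is forwarded to the visual encoder as grid_thw); pixel_attention_mask and image_sizes are accepted purely to keep the signature uniform with the other models in this commit.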


@ -503,8 +503,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
@ -513,7 +514,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
**kwargs,
):
inputs_embeds = self.embed_tokens(input_ids)
@ -525,7 +525,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -535,14 +534,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
attention_mask=None,
inputs_embeds: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model(


@ -577,8 +577,13 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
return inputs
def get_vision_embeds(self, pixel_values, **kwargs):
image_sizes = kwargs.get("image_sizes", None)
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
image_features = self.model.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=self.model.config.vision_config.vision_feature_layer,


@ -721,7 +721,6 @@ class VlmCausalLM(FlashCausalLM):
input_lengths = [max_s] * bs
cache_lengths = [0] * bs
if max_bs is None:
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
inputs_embeds = torch.zeros(
(bs, self.model.config.text_config.hidden_size),
device=self.device,
@ -760,7 +759,6 @@ class VlmCausalLM(FlashCausalLM):
raise RuntimeError(
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
)
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
inputs_embeds = self.cuda_graphs[max_bs]["inputs_embeds"][:bs]
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
if ATTENTION == "flashinfer":
@ -781,7 +779,7 @@ class VlmCausalLM(FlashCausalLM):
)
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
state = create_decode_state_cuda_graphs(
device=input_ids.device,
device=inputs_embeds.device,
block_tables=block_tables,
block_tables_ptr=block_tables_ptr,
last_page_len=last_page_len,
@ -793,7 +791,6 @@ class VlmCausalLM(FlashCausalLM):
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"kv_cache": self.kv_cache,
@ -822,7 +819,6 @@ class VlmCausalLM(FlashCausalLM):
max_k=max_s,
)
self.model.forward(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=None,
@ -847,7 +843,6 @@ class VlmCausalLM(FlashCausalLM):
max_k=max_s,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=None,
@ -1007,14 +1002,21 @@ class VlmCausalLM(FlashCausalLM):
)
batch.position_ids = position_ids
attention_mask = None
attention_mask_forward = None
if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
# Get the mask, needed for flashinfer.
attention_mask = self.model.get_attention_mask(
input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
).reshape(-1)
batch.pixel_values = 1
else:
attention_mask = None
has_image = (input_ids == self.model.config.image_token_index).any()
if has_image:
attention_mask = self.model.get_attention_mask(
input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
)
min_dtype = torch.finfo(self.dtype).min
attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
input_ids.device
)
attention_mask = attention_mask.reshape(-1)
# Try to find an associated cuda graph
bs = input_ids.shape[0]
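The mask handling above produces two views of the same information: a flattened boolean mask handed to flashinfer, and an additive float mask that is later passed to the model's forward as attention_mask. A minimal sketch of the boolean-to-additive conversion used here (the helper name is mine, not from the commit):

import torch


def to_additive_mask(bool_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # True means "attend": keep 0.0 at those positions and fill everything
    # else with the most negative representable value, so masked positions
    # vanish after the softmax inside attention.
    zeros = torch.zeros_like(bool_mask, dtype=dtype)
    neg_inf = torch.full_like(zeros, torch.finfo(dtype).min)
    return torch.where(bool_mask, zeros, neg_inf)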
@ -1049,7 +1051,6 @@ class VlmCausalLM(FlashCausalLM):
max_k=batch.max_current_length,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
@ -1060,27 +1061,17 @@ class VlmCausalLM(FlashCausalLM):
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
attention_mask=attention_mask_forward,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
batch.free_encoder_cache()
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["inputs_embeds"][: inputs_embeds.shape[0]] = inputs_embeds
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
if ATTENTION == "flashinfer":
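After this change the captured CUDA graphs are driven by inputs_embeds instead of input_ids: token and image embedding happens outside the captured region, so only the embedding tensor, the position ids, and the attention metadata need static buffers. An illustrative replay helper, assuming the cuda_graph dict layout suggested by the hunks above ("inputs_embeds", "position_ids", "graph", and "logits" are assumed keys):

import torch


def replay_decode_graph(
    cuda_graph: dict,
    inputs_embeds: torch.Tensor,
    position_ids: torch.Tensor,
) -> torch.Tensor:
    # Static buffers may be padded to a larger batch size, so the live tensors
    # are copied into their leading rows before replaying the recorded graph,
    # and only the matching prefix of the output is read back.
    bs = inputs_embeds.shape[0]
    cuda_graph["inputs_embeds"][:bs] = inputs_embeds
    cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
    cuda_graph["graph"].replay()
    return cuda_graph["logits"][:bs]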