Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
remove kwargs and redundant args
This commit is contained in:
parent 36c5ec2abe
commit 3bb514ddd8
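The change replaces the catch-all `**kwargs` parameters of the VLM `get_vision_embeds` / `get_inputs_embeds` methods with one explicit, shared set of optional multimodal arguments, and drops per-model arguments that had become redundant. A minimal sketch of the pattern (the class names below are illustrative only, not taken from the repository):

from typing import Optional
import torch
import torch.nn as nn

class VisionBackboneBefore(nn.Module):
    def get_vision_embeds(self, pixel_values: torch.FloatTensor, **kwargs):
        # extra inputs disappear silently into kwargs
        ...

class VisionBackboneAfter(nn.Module):
    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        # every implementation now exposes the same optional inputs,
        # even when a given model ignores some of them
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        ...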
@@ -766,7 +766,9 @@ class Gemma3ForConditionalGeneration(nn.Module):
    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        pixel_values = pixel_values.to(dtype=self.dtype)
        image_outputs = self.vision_model(pixel_values)
@@ -781,7 +783,6 @@ class Gemma3ForConditionalGeneration(nn.Module):
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
        **kwargs,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)

@@ -797,7 +798,6 @@ class Gemma3ForConditionalGeneration(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -810,22 +810,13 @@ class Gemma3ForConditionalGeneration(nn.Module):
        pixel_values: torch.FloatTensor = None,
        # Unused here
        attention_mask: Optional[torch.BoolTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if cu_seqlen_prefill is not None:
            max_s += 1
            position_ids += 1

        if pixel_values:
            attention_mask = self.get_attention_mask(
                input_ids,
                cu_seqlen_prefill,
                inputs_embeds.dtype,
            )
        # Use flash attention for text-only input
        # else:
        #     if cu_seqlen_prefill is not None:

@@ -67,7 +67,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        pixel_values = pixel_values.to(dtype=self.dtype)
        image_outputs = self.vision_tower(pixel_values)
@@ -84,7 +86,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
        **kwargs,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)

@@ -736,8 +736,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: torch.BoolTensor,
        **kwargs,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        assert pixel_values is not None
        batch_size, num_images, num_channels, height, width = pixel_values.shape
@@ -805,16 +806,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
        pixel_values: torch.FloatTensor = None,
        pixel_attention_mask: torch.BoolTensor = None,
        **kwargs,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if vision_embeds is None and pixel_values is not None:
            vision_embeds = self.get_vision_embeds(
                pixel_values=pixel_values,
                pixel_attention_mask=pixel_attention_mask,
            )

        if vision_embeds is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
@@ -826,7 +819,6 @@ class Idefics2ForConditionalGeneration(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -836,12 +828,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: torch.FloatTensor = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        # Unused here
        image_sizes: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ):
        hidden_states = self.text_model.model(

@@ -479,8 +479,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: torch.BoolTensor,
        **kwargs,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        batch_size, num_images, num_channels, height, width = pixel_values.shape
        all_states = []
@@ -547,16 +548,8 @@ class Idefics3ForConditionalGeneration(nn.Module):
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
        pixel_values: torch.FloatTensor = None,
        pixel_attention_mask: torch.BoolTensor = None,
        **kwargs,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if vision_embeds is None and pixel_values is not None:
            vision_embeds = self.get_vision_embeds(
                pixel_values=pixel_values,
                pixel_attention_mask=pixel_attention_mask,
            )

        if vision_embeds is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
@@ -568,7 +561,6 @@ class Idefics3ForConditionalGeneration(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -578,14 +570,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: torch.FloatTensor = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        # Unused here
        image_sizes: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ):

@@ -166,8 +166,9 @@ class LlavaNextForConditionalGeneration(nn.Module):
    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: Optional[torch.LongTensor] = None,
        **kwargs,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
        # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
@@ -256,14 +257,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
        vision_embeds: torch.Tensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if vision_embeds is None and pixel_values is not None:
            vision_embeds = self.get_vision_embeds(
                pixel_values=pixel_values,
                image_sizes=image_sizes,
            )

        if vision_embeds is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
@@ -275,7 +270,6 @@ class LlavaNextForConditionalGeneration(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -285,12 +279,9 @@ class LlavaNextForConditionalGeneration(nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: torch.FloatTensor = None,
        # Unused for this model
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ):
        hidden_states = self.text_model.model(

@@ -925,8 +925,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
        return image_embeds
@@ -935,7 +936,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
        **kwargs,
    ):
        inputs_embeds = self.embed_tokens(input_ids)

@@ -947,7 +947,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -957,14 +956,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor],
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        # Unused in this model
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ):

@@ -503,8 +503,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
        return image_embeds
@@ -513,7 +514,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
        **kwargs,
    ):
        inputs_embeds = self.embed_tokens(input_ids)

@@ -525,7 +525,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -535,14 +534,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor],
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
        attention_mask=None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ):
        hidden_states = self.text_model(

@@ -577,8 +577,13 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
        inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
        return inputs

    def get_vision_embeds(self, pixel_values, **kwargs):
        image_sizes = kwargs.get("image_sizes", None)
    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        image_features = self.model.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=self.model.config.vision_config.vision_feature_layer,

@@ -721,7 +721,6 @@ class VlmCausalLM(FlashCausalLM):
        input_lengths = [max_s] * bs
        cache_lengths = [0] * bs
        if max_bs is None:
            input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
            inputs_embeds = torch.zeros(
                (bs, self.model.config.text_config.hidden_size),
                device=self.device,
@@ -760,7 +759,6 @@ class VlmCausalLM(FlashCausalLM):
                raise RuntimeError(
                    "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
                )
            input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
            inputs_embeds = self.cuda_graphs[max_bs]["inputs_embeds"][:bs]
            position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
        if ATTENTION == "flashinfer":
@@ -781,7 +779,7 @@ class VlmCausalLM(FlashCausalLM):
            )
            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
            state = create_decode_state_cuda_graphs(
                device=input_ids.device,
                device=inputs_embeds.device,
                block_tables=block_tables,
                block_tables_ptr=block_tables_ptr,
                last_page_len=last_page_len,
@@ -793,7 +791,6 @@ class VlmCausalLM(FlashCausalLM):

        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "position_ids": position_ids,
            "kv_cache": self.kv_cache,
@@ -822,7 +819,6 @@ class VlmCausalLM(FlashCausalLM):
                max_k=max_s,
            )
            self.model.forward(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
@@ -847,7 +843,6 @@ class VlmCausalLM(FlashCausalLM):
                max_k=max_s,
            )
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                position_ids=position_ids,
                cu_seqlen_prefill=None,

@@ -1007,14 +1002,21 @@ class VlmCausalLM(FlashCausalLM):
            )
            batch.position_ids = position_ids

        attention_mask = None
        attention_mask_forward = None
        if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
            # Get the mask, needed for flashinfer.
            has_image = (input_ids == self.model.config.image_token_index).any()

            if has_image:
                attention_mask = self.model.get_attention_mask(
                    input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
                ).reshape(-1)
                batch.pixel_values = 1
            else:
                attention_mask = None
            )
            min_dtype = torch.finfo(self.dtype).min
            attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
                input_ids.device
            )
            attention_mask = attention_mask.reshape(-1)

        # Try to find an associated cuda graph
        bs = input_ids.shape[0]
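For reference, the boolean-to-additive mask conversion in the hunk above follows the usual pattern: positions that may be attended to become 0.0, masked positions become the most negative value representable in the dtype, so they contribute nothing after softmax. A standalone sketch (the tensor values are invented for illustration):

import torch

dtype = torch.bfloat16
bool_mask = torch.tensor([[True, True, False]])  # True = position may be attended to

min_dtype = torch.finfo(dtype).min               # roughly -3.4e38 for bfloat16
additive_mask = torch.where(bool_mask, 0.0, min_dtype).to(dtype)
# additive_mask is added to the attention scores before softmax:
# kept positions contribute 0.0, masked positions get min_dtype and vanish.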
@@ -1049,7 +1051,6 @@ class VlmCausalLM(FlashCausalLM):
                max_k=batch.max_current_length,
            )
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
@@ -1060,27 +1061,17 @@ class VlmCausalLM(FlashCausalLM):
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
                pixel_attention_mask=batch.pixel_attention_mask,
                image_sizes=batch.image_sizes,
                image_grid_thw=batch.image_grid_thw,
                attention_mask=attention_mask_forward,
            )
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            if batch.pixel_values is not None:
                batch.pixel_values = None
            if batch.pixel_attention_mask is not None:
                batch.pixel_attention_mask = None
            if batch.image_sizes is not None:
                batch.image_sizes = None
            if batch.image_grid_thw is not None:
                batch.image_grid_thw = None
            batch.free_encoder_cache()
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["inputs_embeds"][: inputs_embeds.shape[0]] = inputs_embeds
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        if ATTENTION == "flashinfer":
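The static-buffer copies above follow PyTorch's standard CUDA-graph usage: capture one forward pass over fixed, padded buffers, then copy each new batch into those buffers and replay the captured graph. A generic, self-contained sketch of that pattern (the buffer names and shapes here are invented, not the ones used in this file):

import torch

model = torch.nn.Linear(16, 16).cuda()
static_input = torch.zeros(8, 16, device="cuda")   # padded to the max batch size

# Warm up on a side stream before capture, as recommended by the PyTorch docs.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    static_output = model(static_input)
torch.cuda.current_stream().wait_stream(side)

# Capture a single forward pass into the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# At run time, copy the (possibly smaller) real batch into the padded static
# buffer and replay; results are read back from the static output buffer.
new_batch = torch.randn(4, 16, device="cuda")
static_input[: new_batch.shape[0]].copy_(new_batch)
graph.replay()
result = static_output[: new_batch.shape[0]]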