llama4: split multimodal embedding out of forward, plus assorted fixes

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-06-08 23:56:38 -07:00
parent b09d4cc142
commit 93e5e35f9d
3 changed files with 62 additions and 53 deletions


@@ -1356,55 +1356,36 @@ class Llama4ForConditionalGeneration(nn.Module):
         hidden_state = self.vision_model(pixel_values)
         return hidden_state
 
-    def forward(
+    def get_vision_embeds(
         self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask=None,
-        position_ids: Optional[torch.LongTensor] = None,
-        cu_seqlen_prefill: Optional[torch.Tensor] = None,
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
-        slots: torch.Tensor = None,
-        seqlen: Seqlen = None,
-        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[Union[int, List[int]]] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        image_sizes: torch.Tensor = None,
-        lm_head_indices: Optional[torch.Tensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
-        **lm_kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        def _get_padding_mask(input_ids, pad_token_id=0):
-            return (input_ids != pad_token_id).long()
-        attention_mask = _get_padding_mask(input_ids)
-        attention_mask = attention_mask.view(seqlen.input_lengths.shape[0], -1)
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=self.config.vision_config.vision_feature_layer,
+            vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        image_features = self.multi_modal_projector(vision_flat)
+        return image_features
+
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
         inputs_embeds = self.text_model.model.embed_tokens(input_ids)
-        vision_feature_layer = (
-            vision_feature_layer
-            if vision_feature_layer is not None
-            else self.config.vision_config.vision_feature_layer
-        )
-        vision_feature_select_strategy = (
-            vision_feature_select_strategy
-            if vision_feature_select_strategy is not None
-            else self.config.vision_config.vision_feature_select_strategy
-        )
-        if pixel_values is not None:
-            image_features = self.get_image_features(
-                pixel_values=pixel_values,
-                vision_feature_layer=vision_feature_layer,
-                vision_feature_select_strategy=vision_feature_select_strategy,
-                image_sizes=image_sizes,
-            )
+        if vision_embeds is not None:
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
             # that simply don't exist
             original_inputs_embeds_shape = inputs_embeds.shape
-            vision_flat = image_features.view(-1, image_features.size(-1))
-            projected_vision_flat = self.multi_modal_projector(vision_flat)
             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
                 -1
             )
@@ -1414,19 +1395,33 @@ class Llama4ForConditionalGeneration(nn.Module):
             final_mask_1d = final_mask[..., 0].reshape(-1)
             num_tokens_to_fill = final_mask_1d.sum()
-            if num_tokens_to_fill != projected_vision_flat.size(0):
+            if num_tokens_to_fill != vision_embeds.size(0):
                 raise ValueError(
                     f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
-                    f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
+                    f"but multi_modal_projector returned {vision_embeds.size(0)}"
                 )
             expanded_mask = final_mask_1d.unsqueeze(-1).expand(
                 -1, inputs_embeds.size(-1)
             )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                expanded_mask, projected_vision_flat
-            )
+            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
             inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        cu_seqlen_prefill: Optional[torch.Tensor] = None,
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
+        slots: torch.Tensor = None,
+        seqlen: Seqlen = None,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+        **lm_kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         logits, speculative_logits = self.text_model(
             inputs_embeds,
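
The hunk above replaces the old monolithic forward with a three-step API: encode the images (get_vision_embeds), splice the projected features into the token embeddings at the image_token_index positions (get_inputs_embeds), then run the text model on embeddings only. A minimal caller-side sketch of that flow, assuming a loaded model and the usual batch tensors; the wrapper name and the keyword plumbing through **fwd_kwargs are illustrative, not part of this commit:

def run_llama4_prefill(model, input_ids, pixel_values=None, image_sizes=None, **fwd_kwargs):
    # 1) Encode images once, outside the text forward pass.
    vision_embeds = None
    if pixel_values is not None:
        vision_embeds = model.get_vision_embeds(
            pixel_values=pixel_values, image_sizes=image_sizes
        )
    # 2) Merge the projected image features into the token embeddings at the
    #    positions marked by config.image_token_index.
    inputs_embeds = model.get_inputs_embeds(
        input_ids=input_ids,
        vision_embeds=vision_embeds,
        image_sizes=image_sizes,
    )
    # 3) The language-model forward now consumes embeddings directly
    #    (position_ids, kv_cache, seqlen, etc. are passed through unchanged).
    return model.forward(inputs_embeds=inputs_embeds, **fwd_kwargs)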


@@ -285,7 +285,7 @@ def scatter_image_embeds(
         (is_embed.shape[0], embeds.shape[-1]),
         fill_value=torch.nan,
     )
-    placeholders[is_embed] = embeds
+    placeholders[is_embed.to(embeds.device)] = embeds
     return placeholders
@@ -294,7 +294,7 @@ def gather_image_embeds(
 ) -> Optional[torch.Tensor]:
     if is_embed is None:
         return embeds
-    sel = embeds[is_embed]
+    sel = embeds[is_embed.to(embeds.device)]
     return sel if sel.numel() else None
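
Both one-line fixes above move the boolean mask onto the embedding tensor's device before indexing, so the scatter/gather helpers keep working when the embeddings live on an accelerator while the mask was built on CPU. A CPU-only toy sketch of the same pattern; the stand-in helper bodies and the sample values are assumptions, only the two indexing lines mirror the diff:

import torch

def scatter_sketch(embeds, is_embed):
    # Place each embedding at its slot; unfilled slots stay NaN.
    placeholders = torch.full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
        dtype=embeds.dtype,
        device=embeds.device,
    )
    placeholders[is_embed.to(embeds.device)] = embeds  # mask follows embeds' device
    return placeholders

def gather_sketch(embeds, is_embed):
    sel = embeds[is_embed.to(embeds.device)]  # mask follows embeds' device
    return sel if sel.numel() else None

embeds = torch.arange(6, dtype=torch.float32).reshape(3, 2)  # 3 embeddings of dim 2
is_embed = torch.tensor([True, False, True, True, False])    # 5 slots, 3 are image tokens
scattered = scatter_sketch(embeds, is_embed)   # rows 0, 2, 3 filled; rows 1, 4 stay NaN
gathered = gather_sketch(scattered, is_embed)  # recovers the original 3 embeddings
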
@@ -553,8 +553,12 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
         batch.image_grid_thw = None
         return batch
 
-    def prepare_for_prefill(self):
-        super().prepare_for_prefill()
+    def prepare_for_prefill(
+        self, max_padded_input_len, max_padded_bs, max_total_tokens
+    ):
+        super().prepare_for_prefill(
+            max_padded_input_len, max_padded_bs, max_total_tokens
+        )
 
         self.has_image_inputs = False
         self.cache_entries_to_free = []
@@ -1003,6 +1007,9 @@ class FlashVlmCausalLM(FlashCausalLM):
                 input_ids.device
             )
             attention_mask = attention_mask.reshape(-1)
+        if self.model.config.model_type == "llama4":
+            attention_mask = (input_ids != 0).long()
+            attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)
 
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
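
For llama4 the mask is rebuilt directly from the padded input ids, assuming pad token id 0 as in the hunk above. A toy illustration with invented values (two sequences padded to length 5, so input_lengths.shape[0] is 2):

import torch

input_ids = torch.tensor([101, 102, 103, 0, 0,   # sequence 1, padded to length 5
                          201, 202, 0,   0, 0])  # sequence 2, padded to length 5
attention_mask = (input_ids != 0).long()          # flat 0/1 mask over all tokens
attention_mask_forward = attention_mask.view(2, -1)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0]])  -> one row per sequence for the model forward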


@@ -46,6 +46,13 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
     aspect_ratio_mask: Optional[torch.Tensor] = None
     cross_attention_states: Optional[torch.Tensor] = None
 
+    def prepare_for_prefill(
+        self, max_padded_input_len, max_padded_bs, max_total_tokens
+    ):
+        super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
+            max_padded_input_len, max_padded_bs, max_total_tokens
+        )
+
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches, padded_total_bs: int = 0):
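
The new override forwards the extra prefill arguments through the two-argument form super(FlashVlmCausalLMBatch, self), which starts method lookup after FlashVlmCausalLMBatch in the MRO and therefore bypasses the VLM batch's prepare_for_prefill in favour of the base class's. A minimal sketch of that lookup behaviour with stand-in class names (not the real TGI classes):

class Base:
    def prepare_for_prefill(self, *args):
        return "Base"

class Vlm(Base):
    def prepare_for_prefill(self, *args):
        return "Vlm"

class Mllama(Vlm):
    def prepare_for_prefill(self, *args):
        # Two-argument super: start the MRO search *after* Vlm, just as the
        # real code does with super(FlashVlmCausalLMBatch, self).
        return super(Vlm, self).prepare_for_prefill(*args)

assert Mllama().prepare_for_prefill(1, 2, 3) == "Base"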