diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index c6b68f33..3b30f2e0
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -1356,55 +1356,36 @@ class Llama4ForConditionalGeneration(nn.Module):
         hidden_state = self.vision_model(pixel_values)
         return hidden_state
 
-    def forward(
+    def get_vision_embeds(
         self,
-        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=self.config.vision_config.vision_feature_layer,
+            vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        image_features = self.multi_modal_projector(vision_flat)
+        return image_features
+
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
         pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask=None,
-        position_ids: Optional[torch.LongTensor] = None,
-        cu_seqlen_prefill: Optional[torch.Tensor] = None,
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
-        slots: torch.Tensor = None,
-        seqlen: Seqlen = None,
-        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[Union[int, List[int]]] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        image_sizes: torch.Tensor = None,
-        lm_head_indices: Optional[torch.Tensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
-        **lm_kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-        def _get_padding_mask(input_ids, pad_token_id=0):
-            return (input_ids != pad_token_id).long()
-
-        attention_mask = _get_padding_mask(input_ids)
-        attention_mask = attention_mask.view(seqlen.input_lengths.shape[0], -1)
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
         inputs_embeds = self.text_model.model.embed_tokens(input_ids)
-        vision_feature_layer = (
-            vision_feature_layer
-            if vision_feature_layer is not None
-            else self.config.vision_config.vision_feature_layer
-        )
-        vision_feature_select_strategy = (
-            vision_feature_select_strategy
-            if vision_feature_select_strategy is not None
-            else self.config.vision_config.vision_feature_select_strategy
-        )
-        if pixel_values is not None:
-            image_features = self.get_image_features(
-                pixel_values=pixel_values,
-                vision_feature_layer=vision_feature_layer,
-                vision_feature_select_strategy=vision_feature_select_strategy,
-                image_sizes=image_sizes,
-            )
+        if vision_embeds is not None:
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
             original_inputs_embeds_shape = inputs_embeds.shape
-
-            vision_flat = image_features.view(-1, image_features.size(-1))
-            projected_vision_flat = self.multi_modal_projector(vision_flat)
             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
                 -1
             )
@@ -1414,19 +1395,33 @@ class Llama4ForConditionalGeneration(nn.Module):
             final_mask_1d = final_mask[..., 0].reshape(-1)
             num_tokens_to_fill = final_mask_1d.sum()
 
-            if num_tokens_to_fill != projected_vision_flat.size(0):
+            if num_tokens_to_fill != vision_embeds.size(0):
                 raise ValueError(
                     f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
-                    f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
+                    f"but multi_modal_projector returned {vision_embeds.size(0)}"
                 )
 
             expanded_mask = final_mask_1d.unsqueeze(-1).expand(
                 -1, inputs_embeds.size(-1)
             )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                expanded_mask, projected_vision_flat
-            )
+            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
 
             inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
 
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        cu_seqlen_prefill: Optional[torch.Tensor] = None,
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
+        slots: torch.Tensor = None,
+        seqlen: Seqlen = None,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+        **lm_kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         logits, speculative_logits = self.text_model(
             inputs_embeds,
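The hunks above split the old monolithic `forward` into three phases: `get_vision_embeds` runs the vision tower and `multi_modal_projector`, `get_inputs_embeds` scatters the projected image features into the token-embedding output, and `forward` consumes embeddings only. A minimal sketch of how a caller might chain the three methods; the driver function and its setup are illustrative assumptions, not code from this patch:

```python
import torch

def build_prefill_embeds(model, input_ids, pixel_values=None):
    """Hypothetical driver showing the three-phase call order."""
    vision_embeds = None
    if pixel_values is not None:
        # Phase 1: run the vision tower + multi_modal_projector once.
        vision_embeds = model.get_vision_embeds(pixel_values=pixel_values)
    # Phase 2: embed the tokens and scatter the vision embeddings into the
    # positions that hold image placeholder tokens.
    inputs_embeds = model.get_inputs_embeds(
        input_ids=input_ids, vision_embeds=vision_embeds
    )
    # Phase 3: model.forward(inputs_embeds=..., ...) now takes embeddings
    # only, so decode steps never re-enter the vision path.
    return inputs_embeds
```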
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index dfb49dea..f0129013
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -285,7 +285,7 @@ def scatter_image_embeds(
         (is_embed.shape[0], embeds.shape[-1]),
         fill_value=torch.nan,
     )
-    placeholders[is_embed] = embeds
+    placeholders[is_embed.to(embeds.device)] = embeds
     return placeholders
 
 
@@ -294,7 +294,7 @@ def gather_image_embeds(
 ) -> Optional[torch.Tensor]:
     if is_embed is None:
         return embeds
-    sel = embeds[is_embed]
+    sel = embeds[is_embed.to(embeds.device)]
     return sel if sel.numel() else None
 
 
@@ -553,8 +553,12 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
             batch.image_grid_thw = None
         return batch
 
-    def prepare_for_prefill(self):
-        super().prepare_for_prefill()
+    def prepare_for_prefill(
+        self, max_padded_input_len, max_padded_bs, max_total_tokens
+    ):
+        super().prepare_for_prefill(
+            max_padded_input_len, max_padded_bs, max_total_tokens
+        )
 
         self.has_image_inputs = False
         self.cache_entries_to_free = []
@@ -1003,6 +1007,9 @@ class FlashVlmCausalLM(FlashCausalLM):
                 input_ids.device
             )
             attention_mask = attention_mask.reshape(-1)
+        if self.model.config.model_type == "llama4":
+            attention_mask = (input_ids != 0).long()
+            attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)
 
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
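The `.to(embeds.device)` changes in `scatter_image_embeds` and `gather_image_embeds` guard against indexing an accelerator-resident tensor with a boolean mask that still lives on the CPU, which PyTorch rejects at runtime. A small sketch of the failure mode being avoided; the devices and shapes are illustrative, and the same reasoning applies to HPU tensors:

```python
import torch

embeds = torch.randn(4, 8)                           # on the accelerator in practice
is_embed = torch.tensor([True, False, True, True])   # masks are often built on CPU

# If embeds lived on "hpu"/"cuda" while is_embed stayed on "cpu",
# embeds[is_embed] would raise a device-mismatch RuntimeError; moving the
# mask first is safe, and a no-op when the devices already agree.
sel = embeds[is_embed.to(embeds.device)]
print(sel.shape)  # torch.Size([3, 8])
```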
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index f8089e4c..02b8935d
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -46,6 +46,13 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
     aspect_ratio_mask: Optional[torch.Tensor] = None
     cross_attention_states: Optional[torch.Tensor] = None
 
+    def prepare_for_prefill(
+        self, max_padded_input_len, max_padded_bs, max_total_tokens
+    ):
+        super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
+            max_padded_input_len, max_padded_bs, max_total_tokens
+        )
+
    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches, padded_total_bs: int = 0):
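The override above uses two-argument `super` to skip `FlashVlmCausalLMBatch.prepare_for_prefill` and dispatch straight to the grandparent (`FlashCausalLMBatch`), so the mllama batch keeps the base prefill bookkeeping without the VLM-specific image handling. A toy illustration of that dispatch rule, with hypothetical classes standing in for the real hierarchy:

```python
class Base:                       # stands in for FlashCausalLMBatch
    def prepare_for_prefill(self):
        print("Base.prepare_for_prefill")

class Vlm(Base):                  # stands in for FlashVlmCausalLMBatch
    def prepare_for_prefill(self):
        print("Vlm.prepare_for_prefill")  # skipped by the subclass below
        super().prepare_for_prefill()

class Mllama(Vlm):                # stands in for FlashMllamaCausalLMBatch
    def prepare_for_prefill(self):
        # super(Vlm, self) starts the MRO lookup *after* Vlm, so this
        # calls Base.prepare_for_prefill directly.
        super(Vlm, self).prepare_for_prefill()

Mllama().prepare_for_prefill()    # prints only "Base.prepare_for_prefill"
```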