diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py index e8485df6..4e83ba3c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -1397,7 +1397,9 @@ class Llama4ForConditionalGeneration(nn.Module): vision_flat = image_features.view(-1, image_features.size(-1)) projected_vision_flat = self.multi_modal_projector(vision_flat) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze( + -1 + ) final_mask = special_image_mask.to(inputs_embeds.device) inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))