Fix the image_token_id issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
yuanwu 2025-05-11 22:11:42 +00:00
parent 50ecfc625a
commit f5aaa18d8e

View File

@ -1397,7 +1397,9 @@ class Llama4ForConditionalGeneration(nn.Module):
vision_flat = image_features.view(-1, image_features.size(-1)) vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat) projected_vision_flat = self.multi_modal_projector(vision_flat)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
-1
)
final_mask = special_image_mask.to(inputs_embeds.device) final_mask = special_image_mask.to(inputs_embeds.device)
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))