fix idefics

Mohit Sharma 2025-04-19 14:39:24 +00:00
parent b86919a87a
commit 52e4186c2a
5 changed files with 200 additions and 114 deletions

View File

@@ -116,11 +116,10 @@ class MistralAttention(torch.nn.Module):
)
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
if hasattr(config, "head_dim"):
if hasattr(config, "head_dim") and config.head_dim is not None:
self.head_size = config.head_dim
else:
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
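
Note: the guarded check above (the same change is applied in FlashCausalLM further down) matters presumably because some configs define a head_dim attribute but leave it set to None, in which case the old hasattr-only test would assign head_size = None. A minimal sketch of the fallback logic, using a hypothetical resolve_head_size helper that is not part of this commit:

def resolve_head_size(config):
    # hypothetical helper mirroring the guarded fallback in this commit:
    # prefer an explicit head_dim, but only when it is actually set
    head_dim = getattr(config, "head_dim", None)
    if head_dim is not None:
        return head_dim
    return config.hidden_size // config.num_attention_heads

For example, with head_dim=None, hidden_size=4096 and num_attention_heads=32 this falls back to 128.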

View File

@@ -733,6 +733,97 @@ class Idefics2ForConditionalGeneration(nn.Module):
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: torch.BoolTensor,
**kwargs,
):
assert pixel_values is not None
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: torch.BoolTensor = None,
**kwargs,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is None and pixel_values is not None:
vision_embeds = self.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
)
if vision_embeds is not None:
# When generating, we don't want to replace an image_token_id that the model
# itself produced with image features for images that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@@ -751,76 +842,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@@ -1300,7 +1300,7 @@ class FlashCausalLM(Model):
if head_size is None:
# Some models use GQA and different sizes for o_proj
# and q_proj, that allows for that.
if hasattr(config, "head_dim"):
if hasattr(config, "head_dim") and config.head_dim is not None:
self.head_size = config.head_dim
else:
self.head_size = config.hidden_size // config.num_attention_heads
@@ -1899,9 +1899,6 @@ class FlashCausalLM(Model):
batch.prepare_for_prefill()
self.get_input_embeddings(batch)
from pdb import set_trace
set_trace()
prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present)

View File

@@ -574,7 +574,8 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
return inputs
def get_vision_embeds(self, pixel_values, image_sizes=None):
def get_vision_embeds(self, pixel_values, **kwargs):
image_sizes = kwargs.get("image_sizes", None)
image_features = self.model.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=self.model.config.vision_config.vision_feature_layer,
@@ -586,10 +587,10 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
projected_vision_flat = self.model.multi_modal_projector(vision_flat)
return projected_vision_flat
def get_input_embeds(self, input_ids, vision_embeddings=None):
def get_input_embeds(self, input_ids, vision_embeds=None):
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if vision_embeddings is not None:
if vision_embeds is not None:
original_inputs_embeds_shape = inputs_embeds.shape
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
-1
@@ -600,17 +601,15 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
final_mask_1d = final_mask[..., 0].reshape(-1)
num_tokens_to_fill = final_mask_1d.sum()
if num_tokens_to_fill != vision_embeddings.size(0):
if num_tokens_to_fill != vision_embeds.size(0):
raise ValueError(
f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
f"but multi_modal_projector returned {vision_embeddings.size(0)}"
f"but multi_modal_projector returned {vision_embeds.size(0)}"
)
expanded_mask = final_mask_1d.unsqueeze(-1).expand(
-1, inputs_embeds.size(-1)
)
inputs_embeds = inputs_embeds.masked_scatter(
expanded_mask, vision_embeddings
)
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
return inputs_embeds
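
Note: the rename to vision_embeds above does not change the merge itself, which still relies on masked_scatter to overwrite the rows of inputs_embeds that sit at image placeholder positions. A toy sketch of that merge (hypothetical image_token_index and sizes, not the real model shapes):

import torch

image_token_index = 5
hidden = 4
input_ids = torch.tensor([1, 5, 5, 2])                       # two image placeholder tokens
inputs_embeds = torch.zeros(4, hidden)
vision_embeds = torch.arange(2 * hidden, dtype=torch.float32).view(2, hidden)

mask = (input_ids == image_token_index).unsqueeze(-1)        # (seq_len, 1)
expanded_mask = mask.expand(-1, hidden)                       # (seq_len, hidden)
# the number of masked rows has to match vision_embeds.size(0),
# mirroring the ValueError check in the diff above
assert expanded_mask[:, 0].sum() == vision_embeds.size(0)
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
print(inputs_embeds)
# rows 1 and 2 now hold the two vision embeddings, in order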

View File

@@ -376,7 +376,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
image = image
else:
image = [image]
pixel_values = processor.image_processor(
[image], return_tensors="pt", **kwargs
)
@@ -387,9 +386,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if len(image_inputs) > 0:
batch_image_inputs[i] = image_inputs
from pdb import set_trace
# pixel_values = processor.image_processor(
# all_images, return_tensors="pt", **kwargs
# )
set_trace()
batch_image_positions = []
batch_tokenized_inputs = []
max_length = 0
@@ -410,6 +410,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if config.model_type == "gemma3":
image_text = image_text.replace("\n\n", "")
elif config.model_type == "idefics2":
image_text = image_text_replacement_fixup(config, image_text)
image_tokens.append(
(
@@ -434,9 +436,20 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
prev = 0
image_positions = []
idefics_replacement = False
config.image_token_index = (
config.image_token_index
if hasattr(config, "image_token_index")
else config.image_token_id
)
for image_id, id_token, tokens in image_tokens:
id_token = tokenizer.get_vocab()[id_token]
if config.model_type == "idefics2" and idefics_replacement:
id_token = config.image_token_index
tokens = tokens[1:]
length = len(tokens)
index = input_ids[prev:].index(id_token)
@@ -463,6 +476,13 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
image_positions.append(pos)
prev = index + length
if config.model_type == "idefics2" and prev != len(input_ids):
if input_ids[prev] == config.image_token_id:
# means this is replaced by image_text_replacement_fixup
idefics_replacement = True
else:
idefics_replacement = False
if len(image_positions) > 0:
batch_image_positions[i] = image_positions
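
Note: the config.image_token_index fallback added in the hunk above exists because some multimodal configs expose the placeholder id as image_token_index while others only ship image_token_id. A minimal sketch of that normalization with a dummy config object (the helper name is hypothetical):

from types import SimpleNamespace

def normalize_image_token_index(config):
    # hypothetical helper: prefer image_token_index when present,
    # otherwise fall back to image_token_id, as the batching code now does
    if not hasattr(config, "image_token_index"):
        config.image_token_index = config.image_token_id
    return config.image_token_index

cfg = SimpleNamespace(image_token_id=32001)   # dummy id value
assert normalize_image_token_index(cfg) == 32001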
@@ -540,40 +560,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
self.image_inputs[i][j] = None
if self.has_image and len(scheduled_image_pixel_values):
self.pixel_values = torch.cat(
[d["pixel_values"] for d in scheduled_image_pixel_values], dim=0
).to(device)
self.pixel_values = [
d["pixel_values"].to(device) for d in scheduled_image_pixel_values
]
if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
self.pixel_attention_mask = torch.cat(
[d["pixel_attention_mask"] for d in scheduled_image_pixel_values],
dim=0,
).to(device)
self.pixel_attention_mask = [
d["pixel_attention_mask"].to(device)
for d in scheduled_image_pixel_values
]
if "image_sizes" in scheduled_image_pixel_values[0]:
self.image_sizes = torch.cat(
[d["image_sizes"] for d in scheduled_image_pixel_values], dim=0
).to(device)
self.image_sizes = [
d["image_sizes"].to(device) for d in scheduled_image_pixel_values
]
if "image_grid_thw" in scheduled_image_pixel_values[0]:
self.image_grid_thw = torch.cat(
[d["image_grid_thw"] for d in scheduled_image_pixel_values], dim=0
).to(device)
self.image_grid_thw = [
d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
]
else:
self.pixel_values = None
self.pixel_attention_mask = None
self.image_sizes = None
self.image_grid_thw = None
def update_encoder_cache(self, encoder_outputs):
prev = 0
for i, input in self.scheduled_image_input:
length = input.num_placeholder_tokens
self.encoder_cache[i][input.id] = scatter_image_embeds(
encoder_outputs[prev : prev + length], input.is_embed
)
prev = prev + length
def update_encoder_cache(self, encoder_outputs, batch, input):
self.encoder_cache[batch][input.id] = scatter_image_embeds(
encoder_outputs, input.is_embed
)
def get_mm_embeddings(self):
device = self.input_ids.device
@@ -675,22 +690,75 @@ class VlmCausalLM(FlashCausalLM):
def batch_type(self) -> Type[VlmCausalLMBatch]:
return self.batch_class
def get_vision_embeds(
self,
pixel_values: torch.Tensor,
pixel_attention_mask: torch.Tensor,
image_sizes: torch.Tensor,
image_grid_thw: torch.Tensor,
):
embeds = self.model.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_sizes=image_sizes,
image_grid_thw=image_grid_thw,
)
return embeds
def get_input_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: Optional[torch.Tensor] = None,
):
return self.model.get_input_embeds(
input_ids=input_ids,
vision_embeds=vision_embeds,
)
def get_mm_embeddings(self, batch):
if batch.pixel_values is not None:
encoder_outputs = self.get_vision_embeds(batch.pixel_values)
batch.update_encoder_cache(encoder_outputs)
batch.pixel_values = None
for i, image_input in batch.scheduled_image_input:
from pdb import set_trace
set_trace()
pixel_values = batch.pixel_values[i]
pixel_attention_mask = (
batch.pixel_attention_mask[i]
if batch.pixel_attention_mask is not None
else None
)
image_sizes = (
batch.image_sizes[i] if batch.image_sizes is not None else None
)
image_grid_thw = (
batch.image_grid_thw[i]
if batch.image_grid_thw is not None
else None
)
encoder_outputs = self.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_sizes=image_sizes,
image_grid_thw=image_grid_thw,
)
batch.update_encoder_cache(encoder_outputs, i, image_input)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch.get_mm_embeddings()
def get_input_embeddings(self, batch):
if batch.has_image:
vision_embeddings = self.get_mm_embeddings(batch)
vision_embeds = self.get_mm_embeddings(batch)
batch.has_image = False
else:
vision_embeddings = None
vision_embeds = None
inputs_embeds = self.get_input_embeds(
batch.input_ids, vision_embeddings=vision_embeddings
batch.input_ids, vision_embeds=vision_embeds
)
batch.inputs_embeds = inputs_embeds
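
Note: taken together, the changes in this last file turn image handling into a two-phase pipeline: pixel values are now kept as a per-request list, vision embeddings are computed one scheduled image at a time and cached via update_encoder_cache, and only afterwards merged into the token embeddings through get_input_embeds. An end-to-end toy sketch with dummy stand-ins (ToyVlm, IMAGE_TOKEN and the plain encoder_cache dict are illustrative only, not the real TGI classes):

import torch

IMAGE_TOKEN = 5
HIDDEN = 8

class ToyVlm:
    def __init__(self, vocab_size=16):
        self.embed_tokens = torch.nn.Embedding(vocab_size, HIDDEN)

    def get_vision_embeds(self, pixel_values, **kwargs):
        # stand-in for the vision tower + projector: one HIDDEN-dim vector per image
        num_images = pixel_values.shape[0]
        return pixel_values.flatten(1).mean(dim=1, keepdim=True).expand(num_images, HIDDEN)

    def get_input_embeds(self, input_ids, vision_embeds=None):
        inputs_embeds = self.embed_tokens(input_ids)
        if vision_embeds is not None:
            # overwrite placeholder rows with the vision embeddings, in order
            mask = (input_ids == IMAGE_TOKEN).unsqueeze(-1).expand(-1, HIDDEN)
            inputs_embeds = inputs_embeds.masked_scatter(mask, vision_embeds)
        return inputs_embeds

model = ToyVlm()

# two requests, one 3x32x32 image each, stored as a list with one entry per request
pixel_values = [torch.rand(1, 3, 32, 32), torch.rand(1, 3, 32, 32)]
input_ids = torch.tensor([1, IMAGE_TOKEN, 2, IMAGE_TOKEN, 3])

# phase 1: per-image vision embeddings, cached by request index
encoder_cache = {}
for i, pv in enumerate(pixel_values):
    encoder_cache[i] = model.get_vision_embeds(pixel_values=pv)

# phase 2: gather the cached embeddings and merge them into the token embeddings
vision_embeds = torch.cat([encoder_cache[i] for i in sorted(encoder_cache)], dim=0)
inputs_embeds = model.get_input_embeds(input_ids, vision_embeds=vision_embeds)
print(inputs_embeds.shape)  # torch.Size([5, 8])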