fix: adjust image and text merge logic

drbh 2024-05-10 16:13:11 +00:00 committed by Nicolas Patry
parent d503007fcf
commit 36fb4b5a7a


@@ -175,21 +175,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             config.pad_token_id if config.pad_token_id is not None else -1
         )
 
-    def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds, input_ids
-    ):
-        """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
-        # Let's pray we have enabled enough slots !
-        try:
-            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-        except Exception as e:
-            raise RuntimeError(
-                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
-            )
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
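Note: the removed helper merged vision embeddings into the text sequence by boolean-mask assignment: every position whose token id equals `config.image_token_index` is overwritten with one row of the flattened image features, which only works when the number of placeholder tokens exactly matches the number of image feature rows. A minimal sketch of that mechanism (shapes and the token id are made up for illustration):

```python
import torch

image_token_index = 4                      # hypothetical placeholder id
input_ids = torch.tensor([4, 4, 1, 7, 9])  # two image slots, three text tokens
inputs_embeds = torch.zeros(5, 8)          # (seq_len, hidden_size)
image_features = torch.randn(1, 2, 8)      # (num_images, tokens_per_image, hidden)

# Mark the positions that hold image placeholder tokens.
mask = input_ids == image_token_index
# Flatten to (num_image_tokens, hidden) and scatter into the placeholders;
# this mutates inputs_embeds in place, as the removed docstring said.
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
```

The `try/except` existed because a shape mismatch here raises, and the wrapper turned that into a hint about `--max-input-tokens` at warmup.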
@@ -210,20 +195,21 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             input_ids, self.language_model.model.unscaled_embed_tokens
         )
 
-        if pixel_values is not None:
+        if pixel_values is not None and len(pixel_values) > 0:
             # TODO: avoid these casts upstream
             pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
 
-        # merge text and images
-        if pixel_values is not None and len(pixel_values) > 0:
             image_outputs = self.vision_tower(pixel_values)
-            selected_image_feature = image_outputs.last_hidden_state
-            image_features = self.multi_modal_projector(selected_image_feature)
+            # NOTE: image_features returns the exact values as transformers
+            image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
 
-            # TODO: correctly merge inputs_embeds with image_features
-            merged_inputs_embeds = self._merge_input_ids_with_image_features(
-                image_features, inputs_embeds, input_ids
-            )
+            # TODO: now we scale them? maybe we can do this up or downstream
+            scaled_image_features = image_features / (2048**0.5)
+            # mask where image or padding tokens
+            mask = input_ids == self.config.image_token_index | (input_ids == 2)
+            # insert image features into input embeddings
+            inputs_embeds[mask] = scaled_image_features.view(-1, scaled_image_features.shape[-1])
 
         if input_ids.size(0) != 3000:
             # import ipdb
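Note: the new mask line almost certainly does not do what its comment says. In Python, `|` binds tighter than `==`, so `input_ids == self.config.image_token_index | (input_ids == 2)` parses as `input_ids == (image_token_index | (input_ids == 2))`, which matches the image tokens but silently drops the positions equal to `2` that the comment wants included. A small repro of the difference, with a hypothetical image token id of 4:

```python
import torch

input_ids = torch.tensor([4, 4, 2, 7, 9])
image_token_index = 4  # hypothetical value for illustration

# As written: 4 | (input_ids == 2) broadcasts to tensor([4, 4, 5, 4, 4]),
# so the comparison matches the image tokens but misses the token-id-2 slot.
as_written = input_ids == image_token_index | (input_ids == 2)
print(as_written.tolist())  # [True, True, False, False, False]

# What the "image or padding tokens" comment suggests was intended:
intended = (input_ids == image_token_index) | (input_ids == 2)
print(intended.tolist())    # [True, True, True, False, False]
```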
@@ -231,6 +217,10 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             # ipdb.set_trace()
             pass
 
+        # NOTE: scale back up since we dont normalize inside the model like transformers
+        # TODO: simplify all the rescaling
+        inputs_embeds = inputs_embeds * (2048**0.5)
+
         hidden_states = self.language_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
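Note: the two scaling steps only make sense together. Gemma multiplies token embeddings by `sqrt(hidden_size)` (2048 here) before the first decoder layer; since this path starts from unscaled embeddings and applies that multiplication manually, the image features, which the projector already emits at their final magnitude, were pre-divided by `sqrt(2048)` at merge time so the blanket rescale restores them exactly. A toy round trip, assuming hidden size 2048:

```python
import torch

HIDDEN_SIZE = 2048  # assumed hidden size, matching the 2048**0.5 in the diff
scale = HIDDEN_SIZE**0.5

text_embeds = torch.randn(3, HIDDEN_SIZE)  # unscaled token embeddings
image_feats = torch.randn(2, HIDDEN_SIZE)  # projector output, already at final magnitude

# Pre-divide the image rows so the blanket multiply below cancels for them,
# while the text rows still receive the Gemma-style sqrt(hidden_size) scaling.
merged = torch.cat([image_feats / scale, text_embeds], dim=0)
merged = merged * scale

assert torch.allclose(merged[:2], image_feats, atol=1e-4)
assert torch.allclose(merged[2:], text_embeds * scale, atol=1e-4)
```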
@@ -242,6 +232,10 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             max_s=max_s,
         )
 
+        if input_ids.size(0) != 3000:
+            # import ipdb; ipdb.set_trace()
+            pass
+
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.language_model.lm_head(hidden_states)
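Note: the `lm_head_indices` slice in the final context lines is a common pattern in TGI's flash models: during prefill only a subset of positions (typically the last token of each packed sequence) needs logits, so the head is applied after indexing down. A sketch with hypothetical shapes, using a plain linear layer in place of the speculative head:

```python
import torch

hidden_states = torch.randn(6, 8)             # (packed_seq_len, hidden_size)
lm_head = torch.nn.Linear(8, 11, bias=False)  # toy head over a vocabulary of 11

# Hypothetical: two sequences packed together, ending at positions 2 and 5.
lm_head_indices = torch.tensor([2, 5])

hidden_states = hidden_states[lm_head_indices]  # (2, hidden) instead of (6, hidden)
logits = lm_head(hidden_states)                 # (2, vocab): logits only where needed
```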