fix: typo and lint

This commit is contained in:
drbh 2024-05-10 16:17:59 +00:00 committed by Nicolas Patry
parent 36fb4b5a7a
commit 4df1b25ddb

View File

@ -203,13 +203,15 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
# TODO: now we scale them? maybe we can do this up or downstream
scaled_image_features = image_features / (2048**0.5)
scaled_image_features = image_features / (self.config.hidden_size**0.5)
# mask where image or padding tokens
mask = input_ids == self.config.image_token_index | (input_ids == 2)
# insert image features into input embeddings
inputs_embeds[mask] = scaled_image_features.view(-1, scaled_image_features.shape[-1])
inputs_embeds[mask] = scaled_image_features.view(
-1, scaled_image_features.shape[-1]
)
if input_ids.size(0) != 3000:
# import ipdb
@ -219,7 +221,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
# NOTE: scale back up since we don't normalize inside the model like transformers
# TODO: simplify all the rescaling
inputs_embeds = inputs_embeds * (2048**0.5)
inputs_embeds = inputs_embeds * (self.config.hidden_size**0.5)
hidden_states = self.language_model.model(
inputs_embeds=inputs_embeds,