fix: typo and lint

This commit is contained in:
drbh 2024-05-10 16:17:59 +00:00 committed by Nicolas Patry
parent 36fb4b5a7a
commit 4df1b25ddb

View File

@ -198,18 +198,20 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
if pixel_values is not None and len(pixel_values) > 0:
# TODO: avoid these casts upstream
pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values)
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
# TODO: now we scale them? maybe we can do this up or downstream
scaled_image_features = image_features / (2048**0.5)
scaled_image_features = image_features / (self.config.hidden_size**0.5)
# mask where image or padding tokens
mask = input_ids == self.config.image_token_index | (input_ids == 2)
# insert image features into input embeddings
inputs_embeds[mask] = scaled_image_features.view(-1, scaled_image_features.shape[-1])
inputs_embeds[mask] = scaled_image_features.view(
-1, scaled_image_features.shape[-1]
)
if input_ids.size(0) != 3000:
# import ipdb
@ -219,7 +221,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
# NOTE: scale back up since we dont normalize inside the model like transformers
# TODO: simplify all the rescaling
inputs_embeds = inputs_embeds * (2048**0.5)
inputs_embeds = inputs_embeds * (self.config.hidden_size**0.5)
hidden_states = self.language_model.model(
inputs_embeds=inputs_embeds,