From 4df1b25ddbfa522624f148dfedbc0f11c85f7ead Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 10 May 2024 16:17:59 +0000 Subject: [PATCH] fix: typo and lint --- .../custom_modeling/flash_pali_gemma_modeling.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 7b715618..ffafdd9f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -198,18 +198,20 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module): if pixel_values is not None and len(pixel_values) > 0: # TODO: avoid these casts upstream pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype) - + image_outputs = self.vision_tower(pixel_values) image_features = self.multi_modal_projector(image_outputs.last_hidden_state) # TODO: now we scale them? maybe we can do this up or downstream - scaled_image_features = image_features / (2048**0.5) + scaled_image_features = image_features / (self.config.hidden_size**0.5) # mask where image or padding tokens mask = input_ids == self.config.image_token_index | (input_ids == 2) # insert image features into input embeddings - inputs_embeds[mask] = scaled_image_features.view(-1, scaled_image_features.shape[-1]) + inputs_embeds[mask] = scaled_image_features.view( + -1, scaled_image_features.shape[-1] + ) if input_ids.size(0) != 3000: # import ipdb @@ -219,7 +221,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module): # NOTE: scale back up since we dont normalize inside the model like transformers # TODO: simplify all the rescaling - inputs_embeds = inputs_embeds * (2048**0.5) + inputs_embeds = inputs_embeds * (self.config.hidden_size**0.5) hidden_states = self.language_model.model( inputs_embeds=inputs_embeds,