diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index bb0a55cf..7b715618 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -175,21 +175,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             config.pad_token_id if config.pad_token_id is not None else -1
         )
 
-    def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds, input_ids
-    ):
-        """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
-        # Let's pray we have enabled enough slots !
-        try:
-            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-        except Exception as e:
-            raise RuntimeError(
-                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
-            )
-
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -210,20 +195,21 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             input_ids, self.language_model.model.unscaled_embed_tokens
         )
 
-        if pixel_values is not None:
-            pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
-
-        # merge text and images
         if pixel_values is not None and len(pixel_values) > 0:
+            # TODO: avoid these casts upstream
+            pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
+
             image_outputs = self.vision_tower(pixel_values)
-            selected_image_feature = image_outputs.last_hidden_state
-            image_features = self.multi_modal_projector(selected_image_feature)
-            # NOTE: image_features returns the exact values as transformers
+            image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
 
-            # TODO: correctly merge inputs_embeds with image_features
-            merged_inputs_embeds = self._merge_input_ids_with_image_features(
-                image_features, inputs_embeds, input_ids
-            )
+            # TODO: scale here? maybe this can be done up- or downstream
+            scaled_image_features = image_features / (2048**0.5)  # sqrt(hidden_size)
+
+            # mask of image and padding token positions
+            mask = (input_ids == self.config.image_token_index) | (input_ids == 2)
+
+            # insert image features into input embeddings
+            inputs_embeds[mask] = scaled_image_features.view(-1, scaled_image_features.shape[-1])
 
         if input_ids.size(0) != 3000:
             # import ipdb
@@ -231,6 +217,10 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             # ipdb.set_trace()
             pass
 
+        # NOTE: scale back up since we don't normalize inside the model the way transformers does
+        # TODO: simplify all the rescaling
+        inputs_embeds = inputs_embeds * (2048**0.5)  # sqrt(hidden_size)
+
         hidden_states = self.language_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
@@ -242,6 +232,10 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             max_s=max_s,
         )
 
+        if input_ids.size(0) != 3000:
+            # import ipdb; ipdb.set_trace()
+            pass
+
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.language_model.lm_head(hidden_states)
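
The rescaling in this patch hinges on the embeddings coming from `unscaled_embed_tokens`: image features are pre-divided by sqrt(hidden_size) so that the single multiply before `language_model.model` restores Gemma's usual sqrt(hidden_size) embedding scaling for text and image rows alike. Below is a minimal, self-contained sketch of that round trip, not the actual TGI code: it assumes a hidden size of 2048, hypothetical token ids (`IMAGE_TOKEN`, `PAD_TOKEN`), and dummy tensors in place of the real model.

    import torch

    HIDDEN_SIZE = 2048    # assumed Gemma hidden size used by the patch
    IMAGE_TOKEN = 257152  # hypothetical image_token_index, for illustration only
    PAD_TOKEN = 2         # id the patch also overwrites

    # dummy sequence: three fillable positions, three text tokens
    input_ids = torch.tensor([IMAGE_TOKEN, IMAGE_TOKEN, PAD_TOKEN, 10, 11, 12])
    inputs_embeds = torch.randn(6, HIDDEN_SIZE)   # unscaled text embeddings
    image_features = torch.randn(3, HIDDEN_SIZE)  # projector output (3 patches)

    # pre-divide so the single multiply below scales image rows correctly
    scaled_image_features = image_features / (HIDDEN_SIZE**0.5)

    # parentheses matter: | binds tighter than == in Python
    mask = (input_ids == IMAGE_TOKEN) | (input_ids == PAD_TOKEN)
    inputs_embeds[mask] = scaled_image_features.view(-1, HIDDEN_SIZE)

    # restore Gemma's sqrt(hidden_size) embedding scaling for every row
    inputs_embeds = inputs_embeds * (HIDDEN_SIZE**0.5)
    assert inputs_embeds.shape == (6, HIDDEN_SIZE)

The parentheses around each comparison are why the mask line in the hunk above was corrected: since `|` binds tighter than `==`, the unparenthesized version would evaluate `self.config.image_token_index | (input_ids == 2)` first and compute the wrong mask.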