From 7aebb953e29ed69361243b87d9cde4cf2eaace60 Mon Sep 17 00:00:00 2001
From: drbh
Date: Mon, 26 Aug 2024 17:04:46 -0400
Subject: [PATCH] Fix: don't apply post layernorm in SiglipVisionTransformer (#2459)

* Fix: don't apply post layernorm in SiglipVisionTransformer

This fixes a bug with LLaVA Next when using Siglip as the vision model.
LLaVA Next expects the output of the vision model to be the encoder
outputs before layernorm (see original transformers implementation here:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L813).

This also makes Siglip consistent with the existing Clip implementation:
https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/custom_modeling/clip.py#L613

* fix: adjust pali gemma for post layer norm and small refactors

---------

Co-authored-by: Travis Addair
---
 .../custom_modeling/flash_pali_gemma_modeling.py | 10 +++++++++-
 .../models/custom_modeling/siglip.py             | 13 +------------
 2 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index d10efb41..e08a2aad 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             config=config.vision_config,
             weights=weights,
         )
+        self.post_vision_tower_layernorm = nn.LayerNorm.load(
+            prefix="vision_tower.vision_model.post_layernorm",
+            weights=weights,
+            eps=config.vision_config.layer_norm_eps,
+        )
 
         self.multi_modal_projector = TensorParallelColumnLinear.load(
             config,
@@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         if pixel_values is not None:
             pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
             image_outputs = self.vision_tower(pixel_values)
-            image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
+            last_hidden_state = self.post_vision_tower_layernorm(
+                image_outputs.last_hidden_state
+            )
+            image_features = self.multi_modal_projector(last_hidden_state)
 
             # mask where image or padding tokens
             mask = input_ids == self.config.image_token_index
diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py
index 480d0f9f..95ac9ede 100644
--- a/server/text_generation_server/models/custom_modeling/siglip.py
+++ b/server/text_generation_server/models/custom_modeling/siglip.py
@@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module):
         inputs_embeds,
         attention_mask: Optional[torch.Tensor] = None,
     ):
-
         hidden_states = inputs_embeds
         for idx, encoder_layer in enumerate(self.layers):
             hidden_states, _ = encoder_layer(
@@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module):
         self.encoder = SiglipEncoder(
             prefix=f"{prefix}.encoder", config=config, weights=weights
         )
-        self.post_layernorm = nn.LayerNorm.load(
-            prefix=f"{prefix}.post_layernorm",
-            weights=weights,
-            eps=config.layer_norm_eps,
-        )
 
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
     ):
-        r"""
-        Returns:
-
-        """
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
 
@@ -412,10 +402,9 @@
             inputs_embeds=hidden_states,
         )
         last_hidden_state = encoder_outputs
-        post_last_hidden_state = self.post_layernorm(last_hidden_state)
 
         return BaseModelOutputWithPooling(
-            last_hidden_state=post_last_hidden_state,
+            last_hidden_state=last_hidden_state,
             # pooler_output=pooled_output,
             # hidden_states=encoder_outputs,
         )