From b84303e2e9ea84fc80cc5345a623b7fd2dda1d3f Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Sat, 24 Aug 2024 23:41:23 -0700 Subject: [PATCH] Fix: don't apply post layernorm in SiglipVisionTransformer This fixes a bug with LLaVA Next when using Siglip as the vision model. LLaVA Next expects the output of the vision model to be the encoder outputs before layernorm (see original transformers implementation here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L813). This also makes Siglip consistent with the existing Clip implementation: https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/custom_modeling/clip.py#L613 --- .../models/custom_modeling/siglip.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py index 480d0f9f..feb40472 100644 --- a/server/text_generation_server/models/custom_modeling/siglip.py +++ b/server/text_generation_server/models/custom_modeling/siglip.py @@ -386,11 +386,11 @@ class SiglipVisionTransformer(nn.Module): self.encoder = SiglipEncoder( prefix=f"{prefix}.encoder", config=config, weights=weights ) - self.post_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.post_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) + # self.post_layernorm = nn.LayerNorm.load( + # prefix=f"{prefix}.post_layernorm", + # weights=weights, + # eps=config.layer_norm_eps, + # ) def forward( self, @@ -412,10 +412,10 @@ class SiglipVisionTransformer(nn.Module): inputs_embeds=hidden_states, ) last_hidden_state = encoder_outputs - post_last_hidden_state = self.post_layernorm(last_hidden_state) + # post_last_hidden_state = self.post_layernorm(last_hidden_state) return BaseModelOutputWithPooling( - last_hidden_state=post_last_hidden_state, + last_hidden_state=last_hidden_state, # pooler_output=pooled_output, # hidden_states=encoder_outputs, )