fix: adjust pali gemma for post layer norm and small refactors

This commit is contained in:
drbh 2024-08-26 19:35:39 +00:00
parent b84303e2e9
commit 6256b81baf
2 changed files with 9 additions and 12 deletions

View File

@@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             config=config.vision_config,
             weights=weights,
         )
+        self.post_vision_tower_layernorm = nn.LayerNorm.load(
+            prefix="vision_tower.vision_model.post_layernorm",
+            weights=weights,
+            eps=config.vision_config.layer_norm_eps,
+        )
         self.multi_modal_projector = TensorParallelColumnLinear.load(
             config,
@@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         if pixel_values is not None:
             pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
             image_outputs = self.vision_tower(pixel_values)
-            image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
+            last_hidden_state = self.post_vision_tower_layernorm(
+                image_outputs.last_hidden_state
+            )
+            image_features = self.multi_modal_projector(last_hidden_state)
             # mask where image or padding tokens
             mask = input_ids == self.config.image_token_index

View File

@@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module):
         inputs_embeds,
         attention_mask: Optional[torch.Tensor] = None,
     ):
-
         hidden_states = inputs_embeds
         for idx, encoder_layer in enumerate(self.layers):
             hidden_states, _ = encoder_layer(
@@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module):
         self.encoder = SiglipEncoder(
             prefix=f"{prefix}.encoder", config=config, weights=weights
         )
-        # self.post_layernorm = nn.LayerNorm.load(
-        #     prefix=f"{prefix}.post_layernorm",
-        #     weights=weights,
-        #     eps=config.layer_norm_eps,
-        # )

     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
     ):
-        r"""
-        Returns:
-        """
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
@@ -412,7 +402,6 @@ class SiglipVisionTransformer(nn.Module):
             inputs_embeds=hidden_states,
         )
         last_hidden_state = encoder_outputs
-        # post_last_hidden_state = self.post_layernorm(last_hidden_state)
         return BaseModelOutputWithPooling(
             last_hidden_state=last_hidden_state,