mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: adjust pali gemma for post layer norm and small refactors
This commit is contained in:
parent
b84303e2e9
commit
6256b81baf
@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
config=config.vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
self.post_vision_tower_layernorm = nn.LayerNorm.load(
|
||||
prefix="vision_tower.vision_model.post_layernorm",
|
||||
weights=weights,
|
||||
eps=config.vision_config.layer_norm_eps,
|
||||
)
|
||||
|
||||
self.multi_modal_projector = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
|
||||
image_outputs = self.vision_tower(pixel_values)
|
||||
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
||||
last_hidden_state = self.post_vision_tower_layernorm(
|
||||
image_outputs.last_hidden_state
|
||||
)
|
||||
image_features = self.multi_modal_projector(last_hidden_state)
|
||||
|
||||
# mask where image or padding tokens
|
||||
mask = input_ids == self.config.image_token_index
|
||||
|
@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module):
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
hidden_states, _ = encoder_layer(
|
||||
@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module):
|
||||
self.encoder = SiglipEncoder(
|
||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||
)
|
||||
# self.post_layernorm = nn.LayerNorm.load(
|
||||
# prefix=f"{prefix}.post_layernorm",
|
||||
# weights=weights,
|
||||
# eps=config.layer_norm_eps,
|
||||
# )
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
@ -412,7 +402,6 @@ class SiglipVisionTransformer(nn.Module):
|
||||
inputs_embeds=hidden_states,
|
||||
)
|
||||
last_hidden_state = encoder_outputs
|
||||
# post_last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
|
Loading…
Reference in New Issue
Block a user