mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: adjust pali gemma for post layer norm and small refactors
This commit is contained in:
parent
b84303e2e9
commit
6256b81baf
@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
config=config.vision_config,
|
config=config.vision_config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
self.post_vision_tower_layernorm = nn.LayerNorm.load(
|
||||||
|
prefix="vision_tower.vision_model.post_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.vision_config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
self.multi_modal_projector = TensorParallelColumnLinear.load(
|
self.multi_modal_projector = TensorParallelColumnLinear.load(
|
||||||
config,
|
config,
|
||||||
@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
|
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
|
||||||
image_outputs = self.vision_tower(pixel_values)
|
image_outputs = self.vision_tower(pixel_values)
|
||||||
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
last_hidden_state = self.post_vision_tower_layernorm(
|
||||||
|
image_outputs.last_hidden_state
|
||||||
|
)
|
||||||
|
image_features = self.multi_modal_projector(last_hidden_state)
|
||||||
|
|
||||||
# mask where image or padding tokens
|
# mask where image or padding tokens
|
||||||
mask = input_ids == self.config.image_token_index
|
mask = input_ids == self.config.image_token_index
|
||||||
|
@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module):
|
|||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
for idx, encoder_layer in enumerate(self.layers):
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
hidden_states, _ = encoder_layer(
|
hidden_states, _ = encoder_layer(
|
||||||
@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module):
|
|||||||
self.encoder = SiglipEncoder(
|
self.encoder = SiglipEncoder(
|
||||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||||
)
|
)
|
||||||
# self.post_layernorm = nn.LayerNorm.load(
|
|
||||||
# prefix=f"{prefix}.post_layernorm",
|
|
||||||
# weights=weights,
|
|
||||||
# eps=config.layer_norm_eps,
|
|
||||||
# )
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
):
|
):
|
||||||
r"""
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
"""
|
|
||||||
if pixel_values is None:
|
if pixel_values is None:
|
||||||
raise ValueError("You have to specify pixel_values")
|
raise ValueError("You have to specify pixel_values")
|
||||||
|
|
||||||
@ -412,7 +402,6 @@ class SiglipVisionTransformer(nn.Module):
|
|||||||
inputs_embeds=hidden_states,
|
inputs_embeds=hidden_states,
|
||||||
)
|
)
|
||||||
last_hidden_state = encoder_outputs
|
last_hidden_state = encoder_outputs
|
||||||
# post_last_hidden_state = self.post_layernorm(last_hidden_state)
|
|
||||||
|
|
||||||
return BaseModelOutputWithPooling(
|
return BaseModelOutputWithPooling(
|
||||||
last_hidden_state=last_hidden_state,
|
last_hidden_state=last_hidden_state,
|
||||||
|
Loading…
Reference in New Issue
Block a user