fix: debugging

drbh 2024-05-09 14:23:56 +00:00 committed by Nicolas Patry
parent b07b53efba
commit 23294344c6
2 changed files with 15 additions and 7 deletions

@@ -165,11 +165,23 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         image_outputs = self.vision_tower(pixel_values)
         selected_image_feature = image_outputs.last_hidden_state
         image_features = self.multi_modal_projector(selected_image_feature)
-        # TODO: make sure to handle the specialized attention mask correctly
+        image_features = image_features / (self.config.hidden_size**0.5)
         inputs_embeds = self._merge_input_ids_with_image_features(
             image_features, inputs_embeds, input_ids
         )
+        if input_ids.size(0) != 3000:
+            import ipdb
+            ipdb.set_trace()
+        ## TODO: remove this
+        ## load in values from reference
+        # tensor = torch.load("../../new-model-addition-palma/inputs_embeds.npz")
+        # inputs_embeds = torch.tensor(
+        #     tensor, device=inputs_embeds.device, dtype=inputs_embeds.dtype
+        # ).squeeze()
         hidden_states = self.language_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
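
For context on the first added line: the projected vision features are rescaled by 1/sqrt(hidden_size) so their magnitude matches the text embeddings before the two are merged. The snippet below is a minimal, self-contained sketch of that scale-and-merge step; the function name, the image_token_id argument, and the flattened (seq_len, hidden_size) shapes are illustrative assumptions, not the repository's actual _merge_input_ids_with_image_features implementation.

import torch


def scale_and_merge_image_features(
    input_ids: torch.Tensor,       # (seq_len,) flattened token ids
    inputs_embeds: torch.Tensor,   # (seq_len, hidden_size) text token embeddings
    image_features: torch.Tensor,  # (num_patches, hidden_size) projected vision features
    image_token_id: int,           # id of the image placeholder token (assumed)
    hidden_size: int,
) -> torch.Tensor:
    # Rescale the vision features, mirroring the added line in the diff above.
    image_features = image_features / (hidden_size**0.5)

    # Overwrite the embeddings at image-placeholder positions with the scaled
    # vision features; text positions are left untouched. Assumes the number
    # of placeholder tokens equals the number of image patches.
    mask = input_ids == image_token_id
    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[mask] = image_features.to(inputs_embeds.dtype)
    return inputs_embeds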

@@ -48,10 +48,6 @@ class SiglipVisionEmbeddings(nn.Module):
         self.position_embedding = TensorParallelEmbedding(
             prefix=f"{prefix}.position_embedding", weights=weights
         )
-        # TODO: remove this hack! figure out why off by one
-        self.position_embedding.weight = torch.nn.Parameter(
-            self.position_embedding.weight[:256, :]
-        )
         self.register_buffer(
             "position_ids",
             torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
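
The four removed lines were a workaround that sliced the position embedding table down to 256 rows after loading. For reference, here is a rough plain-PyTorch sketch of how patch and position embeddings line up once num_positions is derived consistently from image_size and patch_size; the class name and the plain nn.Conv2d/nn.Embedding layers are illustrative stand-ins for the repository's tensor-parallel versions, not its actual code.

import torch
from torch import nn


class VisionEmbeddingsSketch(nn.Module):
    """Illustrative patch + position embeddings; not the repo's implementation."""

    def __init__(self, image_size: int, patch_size: int, hidden_size: int, num_channels: int = 3):
        super().__init__()
        self.patch_embedding = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )
        # One position per patch: with image_size and patch_size consistent with
        # the checkpoint, the table already has the right number of rows and no
        # post-hoc truncation is needed.
        self.num_positions = (image_size // patch_size) ** 2
        self.position_embedding = nn.Embedding(self.num_positions, hidden_size)
        self.register_buffer(
            "position_ids", torch.arange(self.num_positions).expand((1, -1))
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # (batch, hidden, h/p, w/p) -> (batch, num_patches, hidden_size)
        patches = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
        return patches + self.position_embedding(self.position_ids)

For example, image_size=224 with patch_size=14 gives (224 // 14) ** 2 = 256 positions, the same count the removed slice was forcing.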
@@ -288,7 +284,7 @@ class SiglipEncoderLayer(nn.Module):
 class SiglipMultiheadAttentionPoolingHead(nn.Module):
     """Multihead Attention Pooling."""

-    def __init__(self, config: SiglipVisionConfig):
+    def __init__(self, prefix, config: SiglipVisionConfig, weights):
         super().__init__()

         self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
@@ -296,7 +292,7 @@
             config.hidden_size, config.num_attention_heads, batch_first=True
         )
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.mlp = SiglipMLP(config)
+        self.mlp = SiglipMLP(prefix, config, weights)

     def forward(self, hidden_state):
         batch_size = hidden_state.shape[0]
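
The signature changes thread prefix and weights through the pooling head so its MLP is built from checkpoint tensors instead of being default-initialized. Below is a rough sketch of that loading pattern; the class name, the fc1/fc2 layer names, and the weights.get_tensor(name) accessor are assumptions for illustration, not the repository's exact API.

import torch
from torch import nn


class MLPFromCheckpointSketch(nn.Module):
    """Illustrative sketch of the (prefix, config, weights) construction pattern."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        self.act = nn.GELU(approximate="tanh")
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        # Pull each tensor from the checkpoint under its fully qualified name,
        # e.g. "<prefix>.fc1.weight"; get_tensor(name) -> torch.Tensor is assumed.
        self.fc1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.fc1.weight"))
        self.fc1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.fc1.bias"))
        self.fc2.weight = nn.Parameter(weights.get_tensor(f"{prefix}.fc2.weight"))
        self.fc2.bias = nn.Parameter(weights.get_tensor(f"{prefix}.fc2.bias"))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(hidden_states)))

A parent module would then construct the head with a matching prefix, for example SiglipMultiheadAttentionPoolingHead(prefix=f"{prefix}.head", config=config, weights=weights); the ".head" suffix is a guess at the checkpoint layout, not something taken from this diff.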