fix: debugging

drbh 2024-05-09 14:23:56 +00:00 committed by Nicolas Patry
parent b07b53efba
commit 23294344c6
2 changed files with 15 additions and 7 deletions


@@ -165,11 +165,23 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        # TODO: make sure to handle the specialized attention mask correctly
        image_features = image_features / (self.config.hidden_size**0.5)
        inputs_embeds = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids
        )
        if input_ids.size(0) != 3000:
            import ipdb
            ipdb.set_trace()
        ## TODO: remove this
        ## load in values from reference
        # tensor = torch.load("../../new-model-addition-palma/inputs_embeds.npz")
        # inputs_embeds = torch.tensor(
        #     tensor, device=inputs_embeds.device, dtype=inputs_embeds.dtype
        # ).squeeze()
        hidden_states = self.language_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,


@@ -48,10 +48,6 @@ class SiglipVisionEmbeddings(nn.Module):
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )
        # TODO: remove this hack! figure out why off by one
        self.position_embedding.weight = torch.nn.Parameter(
            self.position_embedding.weight[:256, :]
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
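For context on the hack removed above: it truncated the position embedding table to 256 rows. A SigLIP vision tower has no class token, so the number of positions should simply equal the number of image patches, and an off-by-one there often points at an extra class-token slot being counted. A small consistency check, as a sketch only; the attribute names follow the usual SiglipVisionConfig and the 224/14 figures are illustrative assumptions:

def check_position_table(config, position_embedding_weight):
    # SigLIP uses no class token, so positions == patches, e.g. (224 // 14) ** 2 == 256.
    num_patches = (config.image_size // config.patch_size) ** 2
    rows = position_embedding_weight.shape[0]
    assert rows == num_patches, f"position table has {rows} rows, expected {num_patches}"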
@@ -288,7 +284,7 @@ class SiglipEncoderLayer(nn.Module):
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: SiglipVisionConfig):
    def __init__(self, prefix, config: SiglipVisionConfig, weights):
        super().__init__()
        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
@@ -296,7 +292,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
            config.hidden_size, config.num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.mlp = SiglipMLP(prefix, config, weights)

    def forward(self, hidden_state):
        batch_size = hidden_state.shape[0]
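The signature change above brings the pooling head in line with the rest of the file: the module receives the checkpoint prefix and the loaded weights so SiglipMLP can read its tensors from the sharded checkpoint instead of initializing them randomly. The pooling itself reduces the patch sequence to a single vector by attending with a learned probe query; a standalone sketch with made-up dimensions:

import torch
from torch import nn

# Illustrative sizes only; they are not taken from the commit.
hidden_size, num_heads, batch, num_patches = 32, 4, 2, 16

probe = nn.Parameter(torch.randn(1, 1, hidden_size))
attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

hidden_state = torch.randn(batch, num_patches, hidden_size)  # encoder patch features
queries = probe.repeat(batch, 1, 1)                          # one learned query per item
pooled, _ = attention(queries, hidden_state, hidden_state)   # attend over all patches
print(pooled.shape)                                          # torch.Size([2, 1, 32])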