From 23294344c6060899a2a3bbe8324a20fd3996e9e2 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 9 May 2024 14:23:56 +0000
Subject: [PATCH] fix: debugging

---
 .../custom_modeling/flash_pali_gemma_modeling.py | 14 +++++++++++++-
 .../models/custom_modeling/siglip.py             |  8 ++------
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index 3e091ebd..2cf51ea1 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -165,11 +165,23 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             image_outputs = self.vision_tower(pixel_values)
             selected_image_feature = image_outputs.last_hidden_state
             image_features = self.multi_modal_projector(selected_image_feature)
-            # TODO: make sure to handle the specialized attention mask correctly
+            image_features = image_features / (self.config.hidden_size**0.5)
             inputs_embeds = self._merge_input_ids_with_image_features(
                 image_features, inputs_embeds, input_ids
             )
 
+        if input_ids.size(0) != 3000:
+            import ipdb
+
+            ipdb.set_trace()
+
+        ## TODO: remove this
+        ## load in values from reference
+        # tensor = torch.load("../../new-model-addition-palma/inputs_embeds.npz")
+        # inputs_embeds = torch.tensor(
+        #     tensor, device=inputs_embeds.device, dtype=inputs_embeds.dtype
+        # ).squeeze()
+
         hidden_states = self.language_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py
index 81795f86..ad3aad45 100644
--- a/server/text_generation_server/models/custom_modeling/siglip.py
+++ b/server/text_generation_server/models/custom_modeling/siglip.py
@@ -48,10 +48,6 @@ class SiglipVisionEmbeddings(nn.Module):
         self.position_embedding = TensorParallelEmbedding(
             prefix=f"{prefix}.position_embedding", weights=weights
         )
-        # TODO: remove this hack! figure out why off by one
-        self.position_embedding.weight = torch.nn.Parameter(
-            self.position_embedding.weight[:256, :]
-        )
         self.register_buffer(
             "position_ids",
             torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
@@ -288,7 +284,7 @@ class SiglipEncoderLayer(nn.Module):
 class SiglipMultiheadAttentionPoolingHead(nn.Module):
     """Multihead Attention Pooling."""
 
-    def __init__(self, config: SiglipVisionConfig):
+    def __init__(self, prefix, config: SiglipVisionConfig, weights):
         super().__init__()
 
         self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
@@ -296,7 +292,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         self.attention = torch.nn.MultiheadAttention(
             config.hidden_size, config.num_attention_heads, batch_first=True
         )
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.mlp = SiglipMLP(config)
+        self.mlp = SiglipMLP(prefix, config, weights)
 
     def forward(self, hidden_state):
         batch_size = hidden_state.shape[0]