mirror of https://github.com/huggingface/text-generation-inference.git
fix: debugging
parent b07b53efba
commit 23294344c6
@@ -165,11 +165,23 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         image_outputs = self.vision_tower(pixel_values)
         selected_image_feature = image_outputs.last_hidden_state
         image_features = self.multi_modal_projector(selected_image_feature)
         # TODO: make sure to handle the specialized attention mask correctly
         image_features = image_features / (self.config.hidden_size**0.5)
         inputs_embeds = self._merge_input_ids_with_image_features(
             image_features, inputs_embeds, input_ids
         )
+
+        if input_ids.size(0) != 3000:
+            import ipdb
+
+            ipdb.set_trace()
+
+        ## TODO: remove this
+        ## load in values from reference
+        # tensor = torch.load("../../new-model-addition-palma/inputs_embeds.npz")
+        # inputs_embeds = torch.tensor(
+        #     tensor, device=inputs_embeds.device, dtype=inputs_embeds.dtype
+        # ).squeeze()
         hidden_states = self.language_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
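The commented-out block in this hunk points at the underlying debugging workflow: dump inputs_embeds from a reference implementation, load the file here, and substitute it to isolate where the two implementations diverge. Below is a minimal, self-contained sketch of that comparison pattern; the helper name, file path argument, and tolerance are illustrative and not part of this repository.

import torch


def compare_with_reference(tensor: torch.Tensor, reference_path: str, atol: float = 1e-3) -> torch.Tensor:
    # Load a tensor dumped by the reference implementation and report the largest mismatch.
    reference = torch.load(reference_path, map_location=tensor.device)
    reference = reference.to(tensor.dtype).squeeze()
    max_diff = (tensor - reference).abs().max().item()
    if max_diff > atol:
        print(f"diverges from reference: max abs diff {max_diff:.6f} (atol={atol})")
    # Return the loaded reference so a caller can optionally swap it in,
    # as the commented-out code in the diff does.
    return reference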
@@ -48,10 +48,6 @@ class SiglipVisionEmbeddings(nn.Module):
         self.position_embedding = TensorParallelEmbedding(
             prefix=f"{prefix}.position_embedding", weights=weights
         )
-        # TODO: remove this hack! figure out why off by one
-        self.position_embedding.weight = torch.nn.Parameter(
-            self.position_embedding.weight[:256, :]
-        )
         self.register_buffer(
             "position_ids",
             torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
@@ -288,7 +284,7 @@ class SiglipEncoderLayer(nn.Module):
 class SiglipMultiheadAttentionPoolingHead(nn.Module):
     """Multihead Attention Pooling."""
 
-    def __init__(self, config: SiglipVisionConfig):
+    def __init__(self, prefix, config: SiglipVisionConfig, weights):
         super().__init__()
 
         self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
@@ -296,7 +292,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
             config.hidden_size, config.num_attention_heads, batch_first=True
         )
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.mlp = SiglipMLP(config)
+        self.mlp = SiglipMLP(prefix, config, weights)
 
     def forward(self, hidden_state):
         batch_size = hidden_state.shape[0]