Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00
fix: debugging
This commit is contained in:
parent b07b53efba
commit 23294344c6
@@ -165,11 +165,23 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             image_outputs = self.vision_tower(pixel_values)
             selected_image_feature = image_outputs.last_hidden_state
             image_features = self.multi_modal_projector(selected_image_feature)
-            # TODO: make sure to handle the specialized attention mask correctly
+            image_features = image_features / (self.config.hidden_size**0.5)
             inputs_embeds = self._merge_input_ids_with_image_features(
                 image_features, inputs_embeds, input_ids
             )
 
+        if input_ids.size(0) != 3000:
+            import ipdb
+
+            ipdb.set_trace()
+
+        ## TODO: remove this
+        ## load in values from reference
+        # tensor = torch.load("../../new-model-addition-palma/inputs_embeds.npz")
+        # inputs_embeds = torch.tensor(
+        #     tensor, device=inputs_embeds.device, dtype=inputs_embeds.dtype
+        # ).squeeze()
+
         hidden_states = self.language_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
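
Note on the replaced line in the hunk above: the new code scales the projected image features by 1/sqrt(hidden_size) before they are merged with the text embeddings. A minimal standalone sketch of that normalization, with an assumed hidden size and token count chosen purely for illustration (not values taken from the repository):

import torch

hidden_size = 2048                      # assumed projection width, illustration only
num_image_tokens = 256                  # assumed number of image tokens
image_features = torch.randn(1, num_image_tokens, hidden_size)

# Same operation as the added line: divide by sqrt(hidden_size)
scaled_features = image_features / (hidden_size ** 0.5)
print(scaled_features.shape)            # torch.Size([1, 256, 2048])
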
@@ -48,10 +48,6 @@ class SiglipVisionEmbeddings(nn.Module):
         self.position_embedding = TensorParallelEmbedding(
             prefix=f"{prefix}.position_embedding", weights=weights
         )
-        # TODO: remove this hack! figure out why off by one
-        self.position_embedding.weight = torch.nn.Parameter(
-            self.position_embedding.weight[:256, :]
-        )
         self.register_buffer(
             "position_ids",
             torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
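
The removed lines above were a temporary hack that truncated the pretrained position-embedding table to 256 rows; what remains is the usual non-trainable position-id buffer. A rough sketch of that buffer pattern in isolation (the module name and size below are assumptions for illustration, not the repository's SiglipVisionEmbeddings):

import torch
import torch.nn as nn

class PositionIdsOnly(nn.Module):
    def __init__(self, num_positions: int = 256):   # assumed size for the sketch
        super().__init__()
        # register_buffer stores a non-trainable (1, num_positions) index tensor
        # that moves with the module across devices, as in the diff above.
        self.register_buffer(
            "position_ids",
            torch.arange(num_positions).expand((1, -1)),
        )

module = PositionIdsOnly()
print(module.position_ids.shape)  # torch.Size([1, 256])
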
@@ -288,7 +284,7 @@ class SiglipEncoderLayer(nn.Module):
 class SiglipMultiheadAttentionPoolingHead(nn.Module):
     """Multihead Attention Pooling."""
 
-    def __init__(self, config: SiglipVisionConfig):
+    def __init__(self, prefix, config: SiglipVisionConfig, weights):
         super().__init__()
 
         self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
@@ -296,7 +292,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
             config.hidden_size, config.num_attention_heads, batch_first=True
         )
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.mlp = SiglipMLP(config)
+        self.mlp = SiglipMLP(prefix, config, weights)
 
     def forward(self, hidden_state):
         batch_size = hidden_state.shape[0]
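
The last two hunks change SiglipMultiheadAttentionPoolingHead so that `prefix` and `weights` are passed through to SiglipMLP; as with TensorParallelEmbedding earlier in the diff, sub-modules are built from named checkpoint tensors rather than random initialization. A rough sketch of that constructor pattern under assumed names (the plain `weights` dict and `LinearFromCheckpoint` below are illustrative stand-ins, not the repository's weights-loader API):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearFromCheckpoint(nn.Module):
    def __init__(self, prefix: str, weights: dict):
        super().__init__()
        # Parameters are looked up by "<prefix>.weight" / "<prefix>.bias"
        # instead of being created with random initialization.
        self.weight = nn.Parameter(weights[f"{prefix}.weight"])
        self.bias = nn.Parameter(weights[f"{prefix}.bias"])

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

# Toy usage with an in-memory "checkpoint":
weights = {
    "head.mlp.fc1.weight": torch.randn(8, 4),
    "head.mlp.fc1.bias": torch.zeros(8),
}
layer = LinearFromCheckpoint("head.mlp.fc1", weights)
print(layer(torch.randn(2, 4)).shape)  # torch.Size([2, 8])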