Fix position ids instantiation logic in the Idefics vision part

This commit is contained in:
Victor SANH 2023-09-26 14:22:23 +02:00 committed by GitHub
parent 2f51645ad7
commit 5a6c5725ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -88,12 +88,10 @@ class IdeficsVisionEmbeddings(nn.Module):
         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches + 1
-        # self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
         self.position_embedding = TensorParallelEmbedding(
             prefix="model.vision_model.embeddings.position_embedding", weights=weights
         )
-        # self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
-        self.position_ids = weights.get_tensor(f"{prefix}.position_ids")
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]