mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Fix position ids logic instantiation of idefics vision part
This commit is contained in:
parent
2f51645ad7
commit
5a6c5725ed
@ -88,12 +88,10 @@ class IdeficsVisionEmbeddings(nn.Module):
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
# self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.position_embedding = TensorParallelEmbedding(
|
||||
prefix="model.vision_model.embeddings.position_embedding", weights=weights
|
||||
)
|
||||
# self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||
self.position_ids = weights.get_tensor(f"{prefix}.position_ids")
|
||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
Loading…
Reference in New Issue
Block a user