Apply suggestions from code review

This commit is contained in:
Nicolas Patry 2023-09-26 15:07:38 +02:00 committed by GitHub
parent 1053e5d09a
commit bf2b92217f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -91,7 +91,7 @@ class IdeficsVisionEmbeddings(nn.Module):
self.position_embedding = TensorParallelEmbedding( self.position_embedding = TensorParallelEmbedding(
prefix="model.vision_model.embeddings.position_embedding", weights=weights prefix="model.vision_model.embeddings.position_embedding", weights=weights
) )
self.position_ids = torch.arange(self.num_positions).expand((1, -1)) self.position_ids = torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0] batch_size = pixel_values.shape[0]