mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Apply suggestions from code review
This commit is contained in:
parent
1053e5d09a
commit
bf2b92217f
@ -91,7 +91,7 @@ class IdeficsVisionEmbeddings(nn.Module):
|
||||
self.position_embedding = TensorParallelEmbedding(
|
||||
prefix="model.vision_model.embeddings.position_embedding", weights=weights
|
||||
)
|
||||
self.position_ids = torch.arange(self.num_positions).expand((1, -1))
|
||||
self.position_ids = torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
Loading…
Reference in New Issue
Block a user