fix: prefer repeat over expand to avoid clone

This commit is contained in:
David Holtz 2024-11-01 21:11:15 +00:00
parent 11b3070ee7
commit 3c8d1f4b2f

View File

@ -472,7 +472,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
position_ids = ( position_ids = (
torch.arange(batch_input_ids.shape[1], device=batch_input_ids.device) torch.arange(batch_input_ids.shape[1], device=batch_input_ids.device)
.view(1, 1, -1) .view(1, 1, -1)
.expand(3, batch_input_ids.shape[0], -1) .repeat(3, batch_input_ids.shape[0], 1)
) )
return position_ids return position_ids