fix: create position ids for text-only input

This commit is contained in:
David Holtz 2024-11-01 15:54:52 +00:00
parent 01dacf8e8f
commit 11b3070ee7

View File

@@ -468,7 +468,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[:, i, :] = llm_positions.to(position_ids.device) position_ids[:, i, :] = llm_positions.to(position_ids.device)
else:
position_ids = (
torch.arange(batch_input_ids.shape[1], device=batch_input_ids.device)
.view(1, 1, -1)
.expand(3, batch_input_ids.shape[0], -1)
)
return position_ids return position_ids
def forward( def forward(