Fix Qwen2 VL crash in benchmark

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date: 2025-06-10 23:30:11 -07:00
Commit: b1ae4ad260
Parent: f72b290020
2 changed files with 26 additions and 22 deletions


@@ -441,22 +441,24 @@ class Qwen2_5VLAttention(nn.Module):
         key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
         key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
         # calc maximum sequence length for any batch
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-        causal = False
-
-        # execute sdpa
-        query = query.unsqueeze(0).transpose(1, 2)
-        key = key.unsqueeze(0).transpose(1, 2)
-        value = value.unsqueeze(0).transpose(1, 2)
+        causal = False
+        query = query.transpose(0, 1)
+        key = key.transpose(0, 1)
+        value = value.transpose(0, 1)
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attention_mask = torch.zeros(
+            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+            ] = True
         attn_output = fsdpa_op(
             query,
             key,
             value,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=causal,
             scale=None,
@@ -464,7 +466,7 @@ class Qwen2_5VLAttention(nn.Module):
             recompute_mode=None,
             valid_sequence_lengths=None,
         )
-        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+        attn_output = attn_output.transpose(0, 1)
         # reshape output to original dimensions
         attn_output = attn_output.reshape(hidden_state.shape[0], -1)
         attn_output = self.proj(attn_output)
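
Note: the change replaces the unmasked SDPA call (attn_mask=None) with an explicit block-diagonal boolean mask built from cu_seqlens, so each packed sequence segment only attends to itself. Below is a minimal, self-contained sketch of that mask construction, not the patch itself: it assumes cu_seqlens is a cumulative-boundary tensor starting at 0, and it uses torch.nn.functional.scaled_dot_product_attention as a stand-in for Habana's Gaudi-only FusedSDPA.

# Minimal sketch of the block-diagonal mask idea used in the patch.
# Assumptions: cu_seqlens = [0, len_0, len_0 + len_1, ...]; plain torch SDPA
# stands in for Habana's FusedSDPA.
import torch
import torch.nn.functional as F

def block_diagonal_mask(cu_seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    # True marks positions allowed to attend: one square block per packed segment.
    mask = torch.zeros(
        [1, max_seqlen, max_seqlen], device=cu_seqlens.device, dtype=torch.bool
    )
    for i in range(1, len(cu_seqlens)):
        start, end = cu_seqlens[i - 1], cu_seqlens[i]
        mask[:, start:end, start:end] = True
    return mask

# Toy usage: two packed segments of lengths 3 and 2, 4 heads, head_dim 8.
cu_seqlens = torch.tensor([0, 3, 5])
q = torch.randn(5, 4, 8)                       # [tokens, heads, head_dim]
k = torch.randn(5, 4, 8)
v = torch.randn(5, 4, 8)
q, k, v = q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1)  # [heads, tokens, head_dim]
mask = block_diagonal_mask(cu_seqlens, max_seqlen=5)                # [1, 5, 5]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)       # [heads, tokens, head_dim]
out = out.transpose(0, 1).reshape(5, -1)                            # [tokens, heads * head_dim]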


@@ -112,22 +112,24 @@ class Qwen2VLAttention(nn.Module):
         key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
         key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
         # calc maximum sequence length for any batch
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-        causal = False
-
-        # execute sdpa
-        query = query.unsqueeze(0).transpose(1, 2)
-        key = key.unsqueeze(0).transpose(1, 2)
-        value = value.unsqueeze(0).transpose(1, 2)
+        causal = False
+        query = query.transpose(0, 1)
+        key = key.transpose(0, 1)
+        value = value.transpose(0, 1)
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attention_mask = torch.zeros(
+            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+            ] = True
         attn_output = fsdpa_op(
             query,
             key,
             value,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=causal,
             scale=None,
@@ -135,7 +137,7 @@ class Qwen2VLAttention(nn.Module):
             recompute_mode=None,
             valid_sequence_lengths=None,
         )
-        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+        attn_output = attn_output.transpose(0, 1)
         # reshape output to original dimensions
         attn_output = attn_output.reshape(hidden_state.shape[0], -1)
         attn_output = self.proj(attn_output)
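
For reference, a small shape walk-through of the layout change (hypothetical sizes, not taken from the patch): the old path built a fake batch dimension with unsqueeze(0).transpose(1, 2), while the new path keeps a head-major [heads, tokens, head_dim] layout via transpose(0, 1) and undoes it symmetrically after attention.

# Shape walk-through with hypothetical sizes (5 packed tokens, 4 heads, head_dim 8).
import torch

x = torch.randn(5, 4, 8)                 # [tokens, heads, head_dim]

old = x.unsqueeze(0).transpose(1, 2)     # old layout: [1, heads, tokens, head_dim] -> [1, 4, 5, 8]
new = x.transpose(0, 1)                  # new layout: [heads, tokens, head_dim]    -> [4, 5, 8]

# Both layouts are undone symmetrically after attention:
assert old.transpose(1, 2).squeeze(0).shape == x.shape
assert new.transpose(0, 1).shape == x.shape

# Either way the result flattens back to [tokens, heads * head_dim] before the
# projection layer that follows (self.proj in the patch).
flat = new.transpose(0, 1).reshape(5, -1)
assert flat.shape == (5, 32)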