Fix Qwen2 VL crash in benchmark

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-06-10 23:30:11 -07:00
parent f72b290020
commit b1ae4ad260
2 changed files with 26 additions and 22 deletions

@@ -441,22 +441,24 @@ class Qwen2_5VLAttention(nn.Module):
         key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
         key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
-        # calc maximum sequence length for any batch
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-        causal = False
-
         # execute sdpa
-        query = query.unsqueeze(0).transpose(1, 2)
-        key = key.unsqueeze(0).transpose(1, 2)
-        value = value.unsqueeze(0).transpose(1, 2)
+        causal = False
+        query = query.transpose(0, 1)
+        key = key.transpose(0, 1)
+        value = value.transpose(0, 1)
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attention_mask = torch.zeros(
+            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+            ] = True
         attn_output = fsdpa_op(
             query,
             key,
             value,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=causal,
             scale=None,
@@ -464,7 +466,7 @@ class Qwen2_5VLAttention(nn.Module):
             recompute_mode=None,
             valid_sequence_lengths=None,
         )
-        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+        attn_output = attn_output.transpose(0, 1)
         # reshape output to original dimensions
         attn_output = attn_output.reshape(hidden_state.shape[0], -1)
         attn_output = self.proj(attn_output)

@@ -112,22 +112,24 @@ class Qwen2VLAttention(nn.Module):
         key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
         key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
-        # calc maximum sequence length for any batch
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-        causal = False
-
         # execute sdpa
-        query = query.unsqueeze(0).transpose(1, 2)
-        key = key.unsqueeze(0).transpose(1, 2)
-        value = value.unsqueeze(0).transpose(1, 2)
+        causal = False
+        query = query.transpose(0, 1)
+        key = key.transpose(0, 1)
+        value = value.transpose(0, 1)
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attention_mask = torch.zeros(
+            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+            ] = True
         attn_output = fsdpa_op(
             query,
             key,
             value,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=causal,
             scale=None,
@@ -135,7 +137,7 @@ class Qwen2VLAttention(nn.Module):
             recompute_mode=None,
             valid_sequence_lengths=None,
         )
-        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+        attn_output = attn_output.transpose(0, 1)
         # reshape output to original dimensions
         attn_output = attn_output.reshape(hidden_state.shape[0], -1)
         attn_output = self.proj(attn_output)
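
For reference, a minimal self-contained sketch of the masking scheme the new code relies on: a block-diagonal boolean mask built from cu_seqlens so that, with is_causal=False, patches of each packed image attend only to patches of the same image. It uses torch.nn.functional.scaled_dot_product_attention as a stand-in for the Gaudi FusedSDPA kernel wrapped by ModuleFusedSDPA; the helper name sdpa_with_cu_seqlens and the toy shapes below are illustrative only, not part of the repository.

import torch
import torch.nn.functional as F


def sdpa_with_cu_seqlens(query, key, value, cu_seqlens, max_seqlen):
    # query/key/value: [seq_len, num_heads, head_dim], as in the vision attention above.
    # Build a [1, max_seqlen, max_seqlen] boolean mask where True marks pairs of
    # positions allowed to attend: one diagonal block per image, delimited by cu_seqlens.
    attention_mask = torch.zeros(
        [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
    )
    for i in range(1, len(cu_seqlens)):
        attention_mask[
            :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
        ] = True

    # SDPA expects [num_heads, seq_len, head_dim]; a single transpose(0, 1) replaces
    # the old unsqueeze(0).transpose(1, 2) round trip.
    query, key, value = (t.transpose(0, 1) for t in (query, key, value))
    attn_output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=False
    )
    return attn_output.transpose(0, 1)  # back to [seq_len, num_heads, head_dim]


# Example: two images of 4 and 6 patches packed into a single 10-token sequence.
q = k = v = torch.randn(10, 2, 8)
cu_seqlens = torch.tensor([0, 4, 10])
out = sdpa_with_cu_seqlens(q, k, v, cu_seqlens, max_seqlen=10)
print(out.shape)  # torch.Size([10, 2, 8])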