mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
fix Qwen2 vl crash in benchmark
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent f72b290020
commit b1ae4ad260
@@ -441,22 +441,24 @@ class Qwen2_5VLAttention(nn.Module):
         key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
         key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
         # calc maximum sequence length for any batch
         query = query.contiguous()
         key = key.contiguous()
         value = value.contiguous()
         causal = False
 
         # execute sdpa
-        query = query.transpose(0, 1)
-        key = key.transpose(0, 1)
-        value = value.transpose(0, 1)
-        causal = False
+        query = query.unsqueeze(0).transpose(1, 2)
+        key = key.unsqueeze(0).transpose(1, 2)
+        value = value.unsqueeze(0).transpose(1, 2)
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attention_mask = torch.zeros(
+            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+            ] = True
         attn_output = fsdpa_op(
             query,
             key,
             value,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=causal,
             scale=None,
@@ -464,7 +466,7 @@ class Qwen2_5VLAttention(nn.Module):
             recompute_mode=None,
             valid_sequence_lengths=None,
         )
-        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
 
         # reshape output to original dimensions
         attn_output = attn_output.reshape(hidden_state.shape[0], -1)
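Both hunks make the same fix: instead of calling FusedSDPA with attn_mask=None, the patch builds a block-diagonal boolean mask from cu_seqlens so that tokens from different packed sequences cannot attend to one another. The following is a minimal standalone sketch of that mask construction, not the patched module itself; the cu_seqlens values are made up for illustration.

import torch

# Illustrative cumulative boundaries for three packed sequences of
# lengths 4, 2, and 3 (hypothetical values, for the demo only).
cu_seqlens = torch.tensor([0, 4, 6, 9])
max_seqlen = int(cu_seqlens[-1])

# Same construction as the patch: True marks positions a query may
# attend to, giving one all-True square block per packed sequence.
attention_mask = torch.zeros([1, max_seqlen, max_seqlen], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    attention_mask[
        :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
    ] = True

print(attention_mask[0].int())  # 4x4, 2x2, 3x3 blocks of ones on the diagonal

The same loop appears verbatim in the Qwen2VLAttention hunk below.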
@@ -112,22 +112,24 @@ class Qwen2VLAttention(nn.Module):
         key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
         key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
         # calc maximum sequence length for any batch
         query = query.contiguous()
         key = key.contiguous()
         value = value.contiguous()
         causal = False
 
         # execute sdpa
-        query = query.transpose(0, 1)
-        key = key.transpose(0, 1)
-        value = value.transpose(0, 1)
-        causal = False
+        query = query.unsqueeze(0).transpose(1, 2)
+        key = key.unsqueeze(0).transpose(1, 2)
+        value = value.unsqueeze(0).transpose(1, 2)
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attention_mask = torch.zeros(
+            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+            ] = True
         attn_output = fsdpa_op(
             query,
             key,
             value,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=causal,
             scale=None,
@@ -135,7 +137,7 @@ class Qwen2VLAttention(nn.Module):
             recompute_mode=None,
             valid_sequence_lengths=None,
         )
-        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
 
         # reshape output to original dimensions
         attn_output = attn_output.reshape(hidden_state.shape[0], -1)
         attn_output = self.proj(attn_output)
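For context on the unsqueeze(0).transpose(1, 2) / transpose(1, 2).squeeze(0) pair: the vision attention carries tensors as [seq, heads, head_dim], while SDPA-style fused kernels expect [batch, heads, seq, head_dim]. The sketch below checks that round trip using torch's scaled_dot_product_attention as a stand-in for Habana's FusedSDPA; sizes and sequence boundaries are assumptions for the demo.

import torch
import torch.nn.functional as F

# Hypothetical sizes for the demo; the real values come from the model.
seq_len, num_heads, head_dim = 9, 2, 8
query = torch.randn(seq_len, num_heads, head_dim)
key = torch.randn(seq_len, num_heads, head_dim)
value = torch.randn(seq_len, num_heads, head_dim)

# [seq, heads, dim] -> [1, heads, seq, dim], as the patch does before
# calling the fused kernel.
q = query.unsqueeze(0).transpose(1, 2)
k = key.unsqueeze(0).transpose(1, 2)
v = value.unsqueeze(0).transpose(1, 2)

# Block-diagonal mask as in the patch (boundaries are illustrative);
# unsqueeze(1) lets it broadcast over the heads dimension.
cu_seqlens = [0, 4, 6, 9]
attention_mask = torch.zeros([1, seq_len, seq_len], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    attention_mask[
        :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
    ] = True

# torch's SDPA stands in for Habana's FusedSDPA here.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask.unsqueeze(1))

# Inverse of the input reshape, matching the patched output line.
out = out.transpose(1, 2).squeeze(0).contiguous()
assert out.shape == (seq_len, num_heads, head_dim)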