Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 07:42:06 +00:00)
fix Qwen2 vl crash in benchmark
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit b1ae4ad260 (parent f72b290020)
@@ -441,22 +441,24 @@ class Qwen2_5VLAttention(nn.Module):
         key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
         key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
-        # calc maximum sequence length for any batch
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-        causal = False
-
         # execute sdpa
-        query = query.unsqueeze(0).transpose(1, 2)
-        key = key.unsqueeze(0).transpose(1, 2)
-        value = value.unsqueeze(0).transpose(1, 2)
+        causal = False
+        query = query.transpose(0, 1)
+        key = key.transpose(0, 1)
+        value = value.transpose(0, 1)
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attention_mask = torch.zeros(
+            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+            ] = True
         attn_output = fsdpa_op(
             query,
             key,
             value,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=causal,
             scale=None,
@@ -464,7 +466,7 @@ class Qwen2_5VLAttention(nn.Module):
             recompute_mode=None,
             valid_sequence_lengths=None,
         )
-        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+        attn_output = attn_output.transpose(0, 1)
 
         # reshape output to original dimensions
         attn_output = attn_output.reshape(hidden_state.shape[0], -1)
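In the hunk above, the patch drops the unmasked, batch-of-one FusedSDPA call and instead builds a block-diagonal boolean mask from cu_seqlens, so patches belonging to different images in the same flattened vision batch can no longer attend to each other. The following is a minimal, illustrative sketch of that mask construction only, not the TGI/Gaudi code path itself; the cu_seqlens values are made up, and the usual SDPA convention is assumed (True means "may attend"):

import torch

# Hypothetical cumulative patch counts for two images (4 and 6 patches).
cu_seqlens = torch.tensor([0, 4, 10])
max_seqlen = int(cu_seqlens[-1])

# Same construction as the added lines: one True block per image on the
# diagonal; everything across image boundaries stays False (masked out).
attention_mask = torch.zeros([1, max_seqlen, max_seqlen], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    attention_mask[
        :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
    ] = True

print(attention_mask[0].int())  # a 4x4 and a 6x6 block of ones, zeros elsewhere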
@@ -112,22 +112,24 @@ class Qwen2VLAttention(nn.Module):
         key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
         key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
-        # calc maximum sequence length for any batch
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-        causal = False
-
         # execute sdpa
-        query = query.unsqueeze(0).transpose(1, 2)
-        key = key.unsqueeze(0).transpose(1, 2)
-        value = value.unsqueeze(0).transpose(1, 2)
+        causal = False
+        query = query.transpose(0, 1)
+        key = key.transpose(0, 1)
+        value = value.transpose(0, 1)
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attention_mask = torch.zeros(
+            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+            ] = True
         attn_output = fsdpa_op(
             query,
             key,
             value,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=causal,
             scale=None,
@@ -135,7 +137,7 @@ class Qwen2VLAttention(nn.Module):
             recompute_mode=None,
             valid_sequence_lengths=None,
         )
-        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+        attn_output = attn_output.transpose(0, 1)
         # reshape output to original dimensions
         attn_output = attn_output.reshape(hidden_state.shape[0], -1)
         attn_output = self.proj(attn_output)
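The Qwen2VLAttention change above is identical to the Qwen2_5VLAttention one. As a rough sanity check of the intent, rather than of the Gaudi kernel, the sketch below substitutes stock torch.nn.functional.scaled_dot_product_attention for FusedSDPA and checks that, with the block-diagonal mask, each image's patches produce the same output they would produce if attended in isolation. Shapes follow the patched code ([num_heads, seq_len, head_dim] after transpose(0, 1)); the sizes, seed, and cu_seqlens values are invented for illustration:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, head_dim = 2, 8
cu_seqlens = torch.tensor([0, 4, 10])
max_seqlen = int(cu_seqlens[-1])

# Flattened patches for two images, laid out as in the patched code.
q = torch.randn(num_heads, max_seqlen, head_dim)
k = torch.randn(num_heads, max_seqlen, head_dim)
v = torch.randn(num_heads, max_seqlen, head_dim)

# Block-diagonal mask, built exactly as in the diff.
mask = torch.zeros([1, max_seqlen, max_seqlen], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    mask[:, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Each image attended alone must match its slice of the masked result.
for i in range(1, len(cu_seqlens)):
    s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    alone = F.scaled_dot_product_attention(q[:, s:e], k[:, s:e], v[:, s:e])
    assert torch.allclose(masked[:, s:e], alone, atol=1e-5)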