diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
index b0efbe1c..a80a86a7 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -441,22 +441,24 @@ class Qwen2_5VLAttention(nn.Module):
         key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
         key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
-        # calc maximum sequence length for any batch
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-        causal = False
-        # execute sdpa
-        query = query.unsqueeze(0).transpose(1, 2)
-        key = key.unsqueeze(0).transpose(1, 2)
-        value = value.unsqueeze(0).transpose(1, 2)
+        causal = False
+        query = query.transpose(0, 1)
+        key = key.transpose(0, 1)
+        value = value.transpose(0, 1)
 
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attention_mask = torch.zeros(
+            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+            ] = True
         attn_output = fsdpa_op(
             query,
             key,
             value,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=causal,
             scale=None,
@@ -464,7 +466,7 @@ class Qwen2_5VLAttention(nn.Module):
             recompute_mode=None,
             valid_sequence_lengths=None,
         )
-        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+        attn_output = attn_output.transpose(0, 1)
 
         # reshape output to original dimensions
         attn_output = attn_output.reshape(hidden_state.shape[0], -1)
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 652def22..96acef31 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -112,22 +112,24 @@ class Qwen2VLAttention(nn.Module):
         key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
         key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
-        # calc maximum sequence length for any batch
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-        causal = False
-        # execute sdpa
-        query = query.unsqueeze(0).transpose(1, 2)
-        key = key.unsqueeze(0).transpose(1, 2)
-        value = value.unsqueeze(0).transpose(1, 2)
+        causal = False
+        query = query.transpose(0, 1)
+        key = key.transpose(0, 1)
+        value = value.transpose(0, 1)
 
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attention_mask = torch.zeros(
+            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+            ] = True
         attn_output = fsdpa_op(
             query,
             key,
             value,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=causal,
             scale=None,
@@ -135,7 +137,7 @@ class Qwen2VLAttention(nn.Module):
             recompute_mode=None,
             valid_sequence_lengths=None,
         )
-        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+        attn_output = attn_output.transpose(0, 1)
         # reshape output to original dimensions
         attn_output = attn_output.reshape(hidden_state.shape[0], -1)
         attn_output = self.proj(attn_output)
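
For context: before this change, both vision attention modules ran the fused SDPA kernel over the packed patch sequence with `attn_mask=None`, so patches belonging to different images in the same batch could attend to each other. The patch keeps the tensors in `(heads, seq, dim)` layout and builds a block-diagonal boolean mask from `cu_seqlens`, restricting attention to within each image's own segment. Below is a minimal, framework-portable sketch of the same masking scheme; `torch.nn.functional.scaled_dot_product_attention` stands in for Gaudi's `ModuleFusedSDPA` wrapper, and the helper name `block_diagonal_mask` is ours for illustration, not from the patch.

```python
import torch
import torch.nn.functional as F


def block_diagonal_mask(cu_seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    # True marks query/key pairs allowed to attend; each [start, end)
    # segment from cu_seqlens attends only within itself, mirroring
    # the loop added in the patch.
    mask = torch.zeros([1, max_seqlen, max_seqlen], dtype=torch.bool)
    for i in range(1, len(cu_seqlens)):
        start, end = cu_seqlens[i - 1], cu_seqlens[i]
        mask[:, start:end, start:end] = True
    return mask


# Two packed sequences of lengths 3 and 2, i.e. cu_seqlens = [0, 3, 5].
cu_seqlens = torch.tensor([0, 3, 5])
max_seqlen, num_heads, head_dim = 5, 2, 8

# (seq, heads, dim) -> (heads, seq, dim), matching the patch's transpose(0, 1).
q = torch.randn(max_seqlen, num_heads, head_dim).transpose(0, 1)
k = torch.randn(max_seqlen, num_heads, head_dim).transpose(0, 1)
v = torch.randn(max_seqlen, num_heads, head_dim).transpose(0, 1)

# (1, seq, seq) boolean mask, broadcast across heads; True = may attend.
mask = block_diagonal_mask(cu_seqlens, max_seqlen)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
out = out.transpose(0, 1).reshape(max_seqlen, -1)  # back to (seq, heads * dim)
```

Note the trade-off the patch accepts: the dense `[1, max_seqlen, max_seqlen]` mask costs O(L²) memory per forward pass, in exchange for correct per-image segment isolation from a single fused-kernel call over the packed sequence.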