fix mllama OOM when batch_size > 8

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date:   2025-06-11 23:18:59 -07:00
Parent: bba260912c
Commit: 027f293098

@@ -237,10 +237,19 @@ class MllamaVisionSdpaAttention(nn.Module):
         key = key.transpose(1, 2)
         value = value.transpose(1, 2)
-        attn_output = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask
+        fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attn_output = fsdpa_op(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            scale=None,
+            softmax_mode="None",
+            recompute_mode=None,
+            valid_sequence_lengths=None,
         )
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
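
The vision-attention hunk above replaces torch's F.scaled_dot_product_attention with the backend's ModuleFusedSDPA wrapper around Habana's FusedSDPA kernel, so attention is computed by the fused Gaudi op rather than the stock path that runs out of memory once batch_size exceeds 8. Below is a minimal sketch of what such a wrapper can look like; the import path, the wrapper shape, and the kernel argument order are assumptions mirrored from the call site in this hunk, not copied from the backend's actual helper.

# Sketch of a ModuleFusedSDPA-style wrapper (assumed shape; the real class is
# defined elsewhere in the Gaudi backend, and the kernel signature below simply
# mirrors the call site in the hunk above).
import torch

try:
    # Habana's fused SDPA kernel; only available on Gaudi/HPU builds.
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:  # non-HPU environments
    FusedSDPA = None


class ModuleFusedSDPA(torch.nn.Module):
    def __init__(self, fused_sdpa):
        super().__init__()
        self._kernel = fused_sdpa

    def forward(
        self,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        softmax_mode,
        recompute_mode,
        valid_sequence_lengths,
    ):
        # Delegate to the fused kernel with the same arguments used above.
        return self._kernel.apply(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            softmax_mode,
            recompute_mode,
            valid_sequence_lengths,
        )

Note that the hunk instantiates the wrapper on every forward; hoisting it into the module's __init__ would be the usual way to avoid re-creating it per call.
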
@@ -705,8 +714,6 @@ class MllamaTextCrossAttention(nn.Module):
         # key_states = key_states.repeat(1, self.num_key_value_groups, 1)
         # value_states = value_states.repeat(1, self.num_key_value_groups, 1)
-        causal = False
         # logger.info(
         #     f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
         # )
@@ -721,7 +728,7 @@ class MllamaTextCrossAttention(nn.Module):
             value_states,
             attn_mask=None,
             dropout_p=0.0,
-            is_causal=causal,
+            is_causal=False,
             scale=None,
             softmax_mode="None",
             recompute_mode=None,
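
The two cross-attention hunks drop the intermediate causal flag and pass is_causal=False directly to the same fused call, leaving the non-causal semantics unchanged. For reference, the snippet below shows the equivalent stock-PyTorch path with the same arguments (attn_mask=None, dropout_p=0.0, is_causal=False); the tensor shapes are illustrative placeholders, not mllama's real dimensions.

# Stock-PyTorch reference for the non-causal cross-attention call above
# (illustrative shapes; runs on CPU, no Gaudi required).
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 8, 16, 32, 64
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Same semantics as the fused call: no mask, no dropout, non-causal.
attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
print(attn_output.shape)  # torch.Size([2, 8, 16, 64])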