mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 07:42:06 +00:00
Fix mllama OOM when batch_size is set above 8
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
bba260912c
commit
027f293098
@@ -237,10 +237,19 @@ class MllamaVisionSdpaAttention(nn.Module):
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
- attn_output = F.scaled_dot_product_attention(
|
||||
- query, key, value, attn_mask=attention_mask
|
||||
+ fsdpa_op = ModuleFusedSDPA(FusedSDPA)
|
||||
+ attn_output = fsdpa_op(
|
||||
+ query,
|
||||
+ key,
|
||||
+ value,
|
||||
+ attn_mask=attention_mask,
|
||||
+ dropout_p=0.0,
|
||||
+ is_causal=False,
|
||||
+ scale=None,
|
||||
+ softmax_mode="None",
|
||||
+ recompute_mode=None,
|
||||
+ valid_sequence_lengths=None,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
|
||||
|
||||
@@ -705,8 +714,6 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
|
||||
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
|
||||
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
|
||||
|
||||
- causal = False
|
||||
# logger.info(
|
||||
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
|
||||
# )
|
||||
@@ -721,7 +728,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
value_states,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
- is_causal=causal,
|
||||
+ is_causal=False,
|
||||
scale=None,
|
||||
softmax_mode="None",
|
||||
recompute_mode=None,
|
||||
|
Loading…
Reference in New Issue
Block a user