mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
fix mllama oom if set batch_size > 8
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
bba260912c
commit
027f293098
@ -237,10 +237,19 @@ class MllamaVisionSdpaAttention(nn.Module):
|
|||||||
key = key.transpose(1, 2)
|
key = key.transpose(1, 2)
|
||||||
value = value.transpose(1, 2)
|
value = value.transpose(1, 2)
|
||||||
|
|
||||||
attn_output = F.scaled_dot_product_attention(
|
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
|
||||||
query, key, value, attn_mask=attention_mask
|
attn_output = fsdpa_op(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
dropout_p=0.0,
|
||||||
|
is_causal=False,
|
||||||
|
scale=None,
|
||||||
|
softmax_mode="None",
|
||||||
|
recompute_mode=None,
|
||||||
|
valid_sequence_lengths=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
|
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
|
||||||
|
|
||||||
@ -705,8 +714,6 @@ class MllamaTextCrossAttention(nn.Module):
|
|||||||
|
|
||||||
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
|
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
|
||||||
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
|
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
|
||||||
|
|
||||||
causal = False
|
|
||||||
# logger.info(
|
# logger.info(
|
||||||
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
|
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
|
||||||
# )
|
# )
|
||||||
@ -721,7 +728,7 @@ class MllamaTextCrossAttention(nn.Module):
|
|||||||
value_states,
|
value_states,
|
||||||
attn_mask=None,
|
attn_mask=None,
|
||||||
dropout_p=0.0,
|
dropout_p=0.0,
|
||||||
is_causal=causal,
|
is_causal=False,
|
||||||
scale=None,
|
scale=None,
|
||||||
softmax_mode="None",
|
softmax_mode="None",
|
||||||
recompute_mode=None,
|
recompute_mode=None,
|
||||||
|
Loading…
Reference in New Issue
Block a user