diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
index ea3129f9..fe6d137b 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
@@ -237,10 +237,19 @@ class MllamaVisionSdpaAttention(nn.Module):
         key = key.transpose(1, 2)
         value = value.transpose(1, 2)
 
-        attn_output = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask
+        fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+        attn_output = fsdpa_op(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            scale=None,
+            softmax_mode="None",
+            recompute_mode=None,
+            valid_sequence_lengths=None,
         )
-
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
 
@@ -705,8 +714,6 @@ class MllamaTextCrossAttention(nn.Module):
 
         # key_states = key_states.repeat(1, self.num_key_value_groups, 1)
         # value_states = value_states.repeat(1, self.num_key_value_groups, 1)
-
-        causal = False
         # logger.info(
         #     f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
         # )
@@ -721,7 +728,7 @@ class MllamaTextCrossAttention(nn.Module):
             value_states,
             attn_mask=None,
             dropout_p=0.0,
-            is_causal=causal,
+            is_causal=False,
             scale=None,
             softmax_mode="None",
             recompute_mode=None,
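
For context, the first hunk replaces PyTorch's F.scaled_dot_product_attention in the vision self-attention with the Gaudi fused SDPA kernel, called through a ModuleFusedSDPA wrapper. A minimal sketch of what such a wrapper typically looks like is below; it mirrors the keyword arguments used at both call sites in this diff, but the wrapper actually shipped in the Gaudi backend may differ, and the padding_side default is an assumption.

# Minimal sketch (assumption): a thin nn.Module facade over Habana's FusedSDPA
# kernel so it can be called like a layer, as fsdpa_op(...) is above.
import torch

try:
    # Gaudi-only import; unavailable on non-HPU machines.
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    FusedSDPA = None


class ModuleFusedSDPA(torch.nn.Module):
    """Expose the fused SDPA kernel behind a module-style call."""

    def __init__(self, fused_sdpa):
        super().__init__()
        self._hpu_kernel_fsdpa = fused_sdpa

    def forward(
        self,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        softmax_mode,
        recompute_mode,
        valid_sequence_lengths,
        padding_side="left",  # assumed default, not taken from the diff
    ):
        # Inputs and output use the same (batch, heads, seq, head_dim) layout as
        # F.scaled_dot_product_attention, so the transpose/reshape that follow the
        # call in the first hunk stay unchanged.
        return self._hpu_kernel_fsdpa.apply(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            softmax_mode,
            recompute_mode,
            valid_sequence_lengths,
            padding_side,
        )

With a wrapper of this shape, fsdpa_op = ModuleFusedSDPA(FusedSDPA) accepts exactly the keyword arguments passed in both hunks; is_causal=False fits the vision self-attention and the text cross-attention, where masking is carried by attn_mask (or omitted), which is also why the separate causal variable could be dropped in the second hunk.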