OlivierDehaene 2024-03-21 09:49:58 +01:00
parent 56296cc43c
commit cfc89bb396


@@ -244,7 +244,9 @@ class FlashCohereAttention(torch.nn.Module):
                 max_s,
             )

-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), reduce=False
+        )


 class CohereMLP(nn.Module):
@@ -282,7 +284,9 @@ class CohereMLP(nn.Module):
     def forward(self, hidden_states):
         gate_up_states = self.gate_up_proj(hidden_states)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False
+        )


 class FlashCohereLayer(nn.Module):
@@ -299,6 +303,7 @@ class FlashCohereLayer(nn.Module):
             weights=weights,
             eps=config.layer_norm_eps,
         )
+        self.process_group = weights.process_group

     def forward(
         self,
@@ -331,6 +336,9 @@ class FlashCohereLayer(nn.Module):
         mlp_output = self.mlp(normed_hidden_states)
         output = attn_output + mlp_output

+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(output, group=self.process_group)
+
         return output, res
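
Taken together, these hunks change the tensor-parallel reduction strategy: `o_proj` and `down_proj` now return per-rank partial sums (`reduce=False`), and the layer performs a single `torch.distributed.all_reduce` on `attn_output + mlp_output`, so each decoder layer issues one collective instead of two. Below is a minimal, self-contained sketch of that pattern; `RowParallelLinearSketch` and `FusedReduceLayerSketch` are hypothetical names for illustration, not the repository's actual classes.

# Minimal sketch of the fused-reduction pattern applied above (hypothetical
# class names). Each row-parallel linear owns a shard of the weight, so its
# matmul yields a partial sum. With reduce=False it skips its own all_reduce
# and lets the layer reduce the summed attention + MLP output once.
import torch
import torch.distributed as dist
from torch import nn


class RowParallelLinearSketch(nn.Module):
    def __init__(self, weight_shard: torch.Tensor, process_group):
        super().__init__()
        self.weight = nn.Parameter(weight_shard)
        self.process_group = process_group

    def forward(self, x: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        out = nn.functional.linear(x, self.weight)
        if reduce and self.process_group.size() > 1:
            # Default path: every projection reduces its own partial sum.
            dist.all_reduce(out, group=self.process_group)
        return out


class FusedReduceLayerSketch(nn.Module):
    def __init__(self, attn: nn.Module, mlp: nn.Module, process_group):
        super().__init__()
        self.attn = attn  # its o_proj is called with reduce=False
        self.mlp = mlp    # its down_proj is called with reduce=False
        self.process_group = process_group

    def forward(self, normed_hidden_states: torch.Tensor) -> torch.Tensor:
        attn_output = self.attn(normed_hidden_states)  # per-rank partial sum
        mlp_output = self.mlp(normed_hidden_states)    # per-rank partial sum
        output = attn_output + mlp_output
        if self.process_group.size() > 1:
            # One all_reduce per layer instead of one per projection.
            dist.all_reduce(output, group=self.process_group)
        return output

Deferring the reduction is correct here because the sum commutes with the all-reduce: reducing attn_output + mlp_output once gives the same result as reducing each projection's output separately, while halving the communication per layer.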