From cfc89bb3969ac275726e695f9445093ef544cc3b Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 21 Mar 2024 09:49:58 +0100
Subject: [PATCH] faster

---
 .../models/custom_modeling/flash_cohere_modeling.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 15b7860f..985bbd8e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -244,7 +244,9 @@ class FlashCohereAttention(torch.nn.Module):
             max_s,
         )
 
-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), reduce=False
+        )
 
 
 class CohereMLP(nn.Module):
@@ -282,7 +284,9 @@ class CohereMLP(nn.Module):
     def forward(self, hidden_states):
         gate_up_states = self.gate_up_proj(hidden_states)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False
+        )
 
 
 class FlashCohereLayer(nn.Module):
@@ -299,6 +303,7 @@ class FlashCohereLayer(nn.Module):
             weights=weights,
             eps=config.layer_norm_eps,
         )
+        self.process_group = weights.process_group
 
     def forward(
         self,
@@ -331,6 +336,9 @@ class FlashCohereLayer(nn.Module):
         mlp_output = self.mlp(normed_hidden_states)
 
         output = attn_output + mlp_output
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(output, group=self.process_group)
+
         return output, res
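
The patch halves the number of all-reduce collectives per decoder layer under tensor parallelism. The row-parallel o_proj and down_proj normally each all-reduce their partial outputs before returning; with reduce=False they return unreduced partial sums, the layer adds the attention and MLP branches locally (Cohere applies both to the same normed input, so they can share one residual add), and a single all_reduce runs on the combined output. This works because all_reduce(SUM) is linear: reducing a sum equals summing the reductions. Below is a minimal sketch of that equivalence, not TGI code; fake_all_reduce simulates torch.distributed.all_reduce by summing per-rank shards in-process, and the shapes and world size are illustrative only.

import torch

def fake_all_reduce(per_rank_tensors):
    # all_reduce(SUM) leaves every rank holding the elementwise
    # sum of all ranks' shards; here we just sum the list.
    return sum(per_rank_tensors)

torch.manual_seed(0)
world_size = 4
# Partial outputs each rank's o_proj / down_proj shard would produce.
attn_partials = [torch.randn(2, 8) for _ in range(world_size)]
mlp_partials = [torch.randn(2, 8) for _ in range(world_size)]

# Before the patch: one all-reduce inside o_proj, another inside down_proj.
two_reductions = fake_all_reduce(attn_partials) + fake_all_reduce(mlp_partials)

# After the patch: both projections skip their reduction (reduce=False),
# each rank adds its partial branches locally, and one all-reduce runs
# on the combined layer output.
one_reduction = fake_all_reduce(
    [a + m for a, m in zip(attn_partials, mlp_partials)]
)

torch.testing.assert_close(two_reductions, one_reduction)
print("fused all-reduce matches the two separate reductions")

The guard in the patched FlashCohereLayer skips the collective entirely when process_group.size() == 1, so single-GPU execution pays no extra cost.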