diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py
index 2e638d77..ac07aa98 100644
--- a/server/text_generation_server/models/flash_neox_modeling.py
+++ b/server/text_generation_server/models/flash_neox_modeling.py
@@ -1,3 +1,23 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 import torch.distributed
 
@@ -16,6 +36,42 @@ import dropout_layer_norm
 from flash_attn.layers.rotary import RotaryEmbedding
 
 
+class FastLayerNorm(nn.LayerNorm):
+    def forward(self, hidden_states, residual=None):
+        if hidden_states.shape[-1] > 6144:
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            return super(FastLayerNorm, self).forward(hidden_states), residual
+        else:
+            (
+                normed_hidden_states,
+                residual,
+                *rest,
+            ) = dropout_layer_norm.dropout_add_ln_fwd(
+                hidden_states,
+                residual,
+                self.weight,
+                self.bias,
+                None,
+                None,
+                None,
+                None,
+                0.0,
+                self.eps,
+                1.0,
+                0,
+                None,
+                False,
+                False,
+            )
+            if residual is None:
+                residual = hidden_states
+
+            return normed_hidden_states, residual
+
+
 class FastLinear(nn.Linear):
     def __init__(
         self,
@@ -59,9 +115,6 @@ class TensorParallelColumnLinear(FastLinear):
             dtype=dtype,
         )
 
-    def forward(self, input):
-        return super(TensorParallelColumnLinear, self).forward(input)
-
 
 class TensorParallelRowLinear(FastLinear):
     def __init__(
@@ -69,12 +122,14 @@ class TensorParallelRowLinear(FastLinear):
         in_features,
         out_features,
         process_group: torch.distributed.ProcessGroup,
+        reduce=True,
         bias=True,
         device=None,
         dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
+        self.reduce = reduce
         assert in_features % self.tp_world_size == 0
         in_features = in_features // self.tp_world_size
 
@@ -88,7 +143,8 @@ class TensorParallelRowLinear(FastLinear):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super(TensorParallelRowLinear, self).forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
 
         return out
 
@@ -196,7 +252,13 @@ class PositionRotaryEmbedding(RotaryEmbedding):
 
 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
-        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
+        self,
+        num_heads,
+        hidden_size,
+        rotary_pct,
+        rotary_emb_base,
+        process_group=None,
+        reduce=True,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -218,9 +280,7 @@ class FlashNeoxAttention(torch.nn.Module):
                 process_group=process_group,
             )
             self.dense = TensorParallelRowLinear(
-                hidden_size,
-                hidden_size,
-                process_group=process_group,
+                hidden_size, hidden_size, process_group=process_group, reduce=reduce
             )
 
     def shuffle_qkv_dims(self):
@@ -309,7 +369,9 @@ class FlashNeoxAttention(torch.nn.Module):
 
 
 class FlashMLP(nn.Module):
-    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+    def __init__(
+        self, act, hidden_size, intermediate_size, process_group=None, reduce=True
+    ):
         super().__init__()
         self.act = (
             ACT2FN[act]
@@ -330,6 +392,7 @@ class FlashMLP(nn.Module):
                 intermediate_size,
                 hidden_size,
                 process_group=process_group,
+                reduce=reduce,
             )
         self.process_group = process_group
 
@@ -355,12 +418,24 @@ class FlashNeoXLayer(nn.Module):
     ):
         super().__init__()
         self.use_parallel_residual = use_parallel_residual
-        self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.attention = FlashNeoxAttention(
-            num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
+            num_heads,
+            hidden_size,
+            rotary_pct,
+            rotary_emb_base,
+            process_group,
+            reduce=not use_parallel_residual,
         )
-        self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
+        self.mlp = FlashMLP(
+            act,
+            hidden_size,
+            intermediate_size,
+            process_group,
+            reduce=not use_parallel_residual,
+        )
+        self.process_group = process_group
 
     def forward(
         self,
@@ -375,24 +450,7 @@ class FlashNeoXLayer(nn.Module):
         cu_seqlens_q,
     ):
         if self.use_parallel_residual:
-            # faster input layer norm
-            ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                None,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.input_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            ln1_hidden_states, _ = self.input_layernorm(hidden_states)
 
             attn_output = self.attention(
                 ln1_hidden_states,
@@ -405,46 +463,18 @@ class FlashNeoXLayer(nn.Module):
                 cu_seqlens_q,
             )
 
-            # faster post attention layer norm
-            ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                None,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.post_attention_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
 
             mlp_output = self.mlp(ln2_hidden_states)
-            return mlp_output + attn_output + hidden_states, None
+            intermediate = mlp_output + attn_output
+
+            # Only reduce once and after the addition instead of once per layer
+            if self.process_group is not None:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+            return intermediate + hidden_states, None
         else:
-            # faster input layer norm
-            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.input_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
             hidden_states = self.attention(
                 hidden_states,
@@ -457,23 +487,8 @@ class FlashNeoXLayer(nn.Module):
                 cu_seqlens_q,
             )
 
-            # faster post attention layer norm
-            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.post_attention_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
             )
 
             mlp_output = self.mlp(hidden_states)
@@ -523,7 +538,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 for _ in range(config.num_hidden_layers)
             ]
         )
-        self.final_layer_norm = nn.LayerNorm(
+        self.final_layer_norm = FastLayerNorm(
             config.hidden_size, eps=config.layer_norm_eps
         )
 
@@ -603,24 +618,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 cu_seqlens_q,
             )
 
-        # Faster final layer norm
-        hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            residual,
-            self.final_layer_norm.weight,
-            self.final_layer_norm.bias,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.final_layer_norm.eps,
-            1.0,
-            0,
-            None,
-            False,
-            False,
-        )
+        hidden_states, _ = self.final_layer_norm(hidden_states, residual)
 
         return hidden_states, past_key_values