diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index be2e9bc7..a4424dc9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -481,6 +481,7 @@ class FlashRWLayer(nn.Module):

         return mlp_output, residual

+
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -602,7 +603,7 @@ class FlashRWLargeLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         prefix = f"transformer.h.{layer_id}"
-
+        self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
         self.self_attention = FlashRWLargeAttention(
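
In plain terms: the first hunk only adds a PEP 8 blank line before the existing FlashRWLayerNorm class, while the second hunk wires that wrapper into FlashRWLargeLayer as self.ln_layer, presumably so the block's forward pass can delegate its pre-attention normalization to one shared module. Below is a minimal sketch of what such a wrapper could look like. Only the (config, prefix, weights) constructor signature is taken from the diff; the nn.LayerNorm body, the config attribute names, and the (normed, residual) return convention are illustrative assumptions, not the repository's actual implementation:

from torch import nn


class FlashRWLayerNormSketch(nn.Module):
    # Hypothetical stand-in for FlashRWLayerNorm. Only the
    # (config, prefix, weights) signature comes from the diff above;
    # everything else here is an assumption for illustration.
    def __init__(self, config, prefix, weights):
        super().__init__()
        # Assumption: the config exposes hidden_size and
        # layer_norm_epsilon, as Falcon-style configs typically do.
        self.ln = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_epsilon
        )

    def forward(self, hidden_states, residual=None):
        # Assumption: fold the residual in before normalizing and
        # return both the normed states and the new residual.
        if residual is not None:
            hidden_states = hidden_states + residual
        return self.ln(hidden_states), hidden_states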