From 5e8db7c14f7ca2709cd108fb68c41ed27bb2ed86 Mon Sep 17 00:00:00 2001
From: Nilabhra
Date: Mon, 15 Apr 2024 13:52:20 +0400
Subject: [PATCH] add: support for falcon-10B architecture.

---
 .../custom_modeling/flash_rw_modeling.py      | 41 ++++++++++++++++++-
 1 file changed, 39 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index a2236a3d..bb53ee20 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -7,8 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 
 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import attention
-from text_generation_server.layers import (
+from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -480,6 +479,44 @@ class FlashRWLayer(nn.Module):
             return mlp_output, residual
 
 
+class FlashRWLayerNorm(nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.num_ln = config.num_ln_in_parallel_attn
+
+        if self.num_ln == 1:
+            self.input_ln = FastLayerNorm.load(
+                prefix=f"{prefix}.input_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        elif self.num_ln == 2:
+            self.ln_attn = FastLayerNorm.load(
+                prefix=f"{prefix}.ln_attn",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+            self.ln_mlp = FastLayerNorm.load(
+                prefix=f"{prefix}.ln_mlp",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        else:
+            raise ValueError("Number of layer norms can either be 1 or 2.")
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+    ):
+        if self.num_ln == 1:
+            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
+            return ln_hidden_states, ln_hidden_states, residual
+        elif self.num_ln == 2:
+            ln_attn, residual = self.ln_attn(hidden_states, residual)
+            ln_mlp, _ = self.ln_mlp(residual)
+            return ln_attn, ln_mlp, residual
+
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix, weights):
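
Note, not part of the diff above: the snippet below is a minimal, self-contained sketch of the behaviour the new FlashRWLayerNorm encodes, so the single-norm and dual-norm code paths can be exercised outside text-generation-server. Plain torch.nn.LayerNorm stands in for FastLayerNorm, whose kernel, as far as I understand it, also fuses the residual addition; the class name ParallelLayerNormSketch and its constructor arguments are illustrative assumptions, not identifiers from the repository.

# Illustrative stand-in, not repository code: mimics the 1-vs-2 layer-norm
# selection introduced by the patch, using plain torch.nn.LayerNorm.
from typing import Optional

import torch
from torch import nn


class ParallelLayerNormSketch(nn.Module):
    def __init__(self, hidden_size: int, num_ln_in_parallel_attn: int, eps: float = 1e-5):
        super().__init__()
        self.num_ln = num_ln_in_parallel_attn
        if self.num_ln == 1:
            # Falcon-7B-style block: a single layer norm feeds both the
            # attention branch and the MLP branch.
            self.input_ln = nn.LayerNorm(hidden_size, eps=eps)
        elif self.num_ln == 2:
            # Falcon-40B-style block: separate norms for attention and MLP.
            self.ln_attn = nn.LayerNorm(hidden_size, eps=eps)
            self.ln_mlp = nn.LayerNorm(hidden_size, eps=eps)
        else:
            raise ValueError("Number of layer norms can either be 1 or 2.")

    def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor]):
        # Residual handling modeled on FastLayerNorm: the residual (if any) is
        # added before normalization, and the pre-norm sum is returned as the
        # new residual.
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        if self.num_ln == 1:
            ln = self.input_ln(hidden_states)
            return ln, ln, residual
        return self.ln_attn(hidden_states), self.ln_mlp(hidden_states), residual


if __name__ == "__main__":
    x = torch.randn(4, 64)
    shared = ParallelLayerNormSketch(64, num_ln_in_parallel_attn=1)
    split = ParallelLayerNormSketch(64, num_ln_in_parallel_attn=2)
    attn_in, mlp_in, res = shared(x, None)
    print(torch.equal(attn_in, mlp_in))  # True: both branches see the same tensor
    attn_in, mlp_in, res = split(x, None)
    print(attn_in.shape, mlp_in.shape, res.shape)

With num_ln_in_parallel_attn=1 both branches receive the same normalized tensor, which matches the patched forward returning (ln_hidden_states, ln_hidden_states, residual); with 2, the attention and MLP branches get independently normalized copies of the same pre-norm sum.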