From 3280d59f19e847317b35bedd7c9f050a835d7d03 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 6 Aug 2024 19:08:09 +0000
Subject: [PATCH] fix: prefer original layernorm names for 180B

---
 .../models/custom_modeling/flash_rw_modeling.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 0691da9b..fc002082 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -382,8 +382,13 @@ class FlashRWLayer(nn.Module):
 
         prefix = f"{prefix}.h.{layer_id}"
 
+        # NOTE: Falcon 180B uses the ln_attn prefix
+        ln_prefix = "input_layernorm"
+        if config.num_hidden_layers == 80:
+            ln_prefix = "ln_attn"
+
         self.input_layernorm = FastLayerNorm.load(
-            prefix=f"{prefix}.input_layernorm",
+            prefix=f"{prefix}.{ln_prefix}",
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
@@ -477,6 +482,10 @@ class FlashRWLayerNorm(nn.Module):
         # in the case no number of layer norms is provided, we default to 1
         self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
 
+        # Falcon 180B uses the ln_attn prefix and has 2 layer norms
+        if config.num_hidden_layers == 80:
+            self.num_ln = 2
+
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",
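
Below is a minimal sketch, not part of the patch, of the heuristic the diff introduces: a checkpoint with 80 hidden layers (the Falcon 180B depth) keeps the original "ln_attn" layer-norm weight name, while other RW checkpoints load "input_layernorm". The layernorm_prefix helper, the SimpleNamespace configs, and the "transformer.h.0" layer prefix are hypothetical and used only for illustration.

    from types import SimpleNamespace

    def layernorm_prefix(config, layer_prefix: str) -> str:
        """Pick the weight name the attention layer norm is loaded from."""
        ln_name = "input_layernorm"
        if config.num_hidden_layers == 80:  # depth used by Falcon 180B
            ln_name = "ln_attn"
        return f"{layer_prefix}.{ln_name}"

    # Hypothetical configs: only num_hidden_layers matters for this check.
    falcon_180b = SimpleNamespace(num_hidden_layers=80)
    smaller_rw = SimpleNamespace(num_hidden_layers=32)

    print(layernorm_prefix(falcon_180b, "transformer.h.0"))  # transformer.h.0.ln_attn
    print(layernorm_prefix(smaller_rw, "transformer.h.0"))   # transformer.h.0.input_layernorm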