From 3ccde430d95157a10a6272950e3b6a6031db6b53 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 6 Aug 2024 15:25:30 -0400
Subject: [PATCH] fix: prefer original layernorm names for 180B (#2365)

---
 .../models/custom_modeling/flash_rw_modeling.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 0691da9b..fc002082 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -382,8 +382,13 @@ class FlashRWLayer(nn.Module):
 
         prefix = f"{prefix}.h.{layer_id}"
 
+        # NOTE: Falcon 180B uses the ln_attn prefix
+        ln_prefix = "input_layernorm"
+        if config.num_hidden_layers == 80:
+            ln_prefix = "ln_attn"
+
         self.input_layernorm = FastLayerNorm.load(
-            prefix=f"{prefix}.input_layernorm",
+            prefix=f"{prefix}.{ln_prefix}",
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
@@ -477,6 +482,10 @@ class FlashRWLayerNorm(nn.Module):
         # in the case no number of layer norms is provided, we default to 1
         self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
 
+        # Falcon 180B uses the ln_attn prefix and has 2 layer norms
+        if config.num_hidden_layers == 80:
+            self.num_ln = 2
+
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",
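
For context, a minimal sketch of the weight-name selection this patch introduces. The helper below is hypothetical (it is not part of the patch or of text-generation-inference); it only illustrates the mapping the diff hard-codes: Falcon 180B checkpoints (config.num_hidden_layers == 80) store the attention layer norm under "ln_attn" and use two parallel layer norms, while the other Falcon/RW checkpoints keep the original "input_layernorm" name.

# Hypothetical illustration only, not code from the patch above.
def layernorm_prefix(num_hidden_layers: int) -> str:
    """Return the checkpoint weight prefix used for the attention layer norm."""
    # Falcon 180B has 80 hidden layers and names this tensor "ln_attn";
    # smaller Falcon/RW checkpoints use "input_layernorm".
    return "ln_attn" if num_hidden_layers == 80 else "input_layernorm"


if __name__ == "__main__":
    assert layernorm_prefix(80) == "ln_attn"          # Falcon 180B
    assert layernorm_prefix(32) == "input_layernorm"  # smaller Falcon/RW models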