fix: default num_ln_in_parallel_attn to one if not supplied (#2364)

drbh 2024-08-06 13:33:22 -04:00 committed by yuanwu
parent 5400c7155d
commit db873be177


@@ -473,7 +473,9 @@ class FlashRWLayer(nn.Module):
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix: str, weights):
         super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
+        # Falcon2 includes the number of layer norms in the config
+        # in the case no number of layer norms is provided, we default to 1
+        self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
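The change swaps a direct attribute access for getattr with a default, so configs that predate Falcon2's num_ln_in_parallel_attn field no longer raise AttributeError when the layer is constructed. A minimal sketch of that fallback, using SimpleNamespace as a stand-in for the real config class (the attribute name comes from the diff; the config objects here are illustrative):

from types import SimpleNamespace

# Falcon2-style config that declares the number of layer norms
# used in parallel attention.
falcon2_config = SimpleNamespace(num_ln_in_parallel_attn=2)
# Older Falcon config that predates the field entirely.
legacy_config = SimpleNamespace()

# getattr with a default returns 1 when the attribute is missing,
# where the old direct access (config.num_ln_in_parallel_attn)
# would raise AttributeError.
print(getattr(falcon2_config, "num_ln_in_parallel_attn", 1))  # -> 2
print(getattr(legacy_config, "num_ln_in_parallel_attn", 1))   # -> 1

Note that getattr only covers the missing-attribute case: a config that explicitly sets the field to None would still propagate None rather than falling back to 1.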