Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-22 15:32:08 +00:00
fix: default num_ln_in_parallel_attn to one if not supplied (#2364)
parent 5400c7155d
commit db873be177
@@ -473,7 +473,9 @@ class FlashRWLayer(nn.Module):
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix: str, weights):
         super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
+        # Falcon2 includes the number of layer norms in the config
+        # in the case no number of layer norms is provided, we default to 1
+        self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)

         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
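
The change swaps a direct attribute read for getattr with a default of 1: Falcon2 configs carry num_ln_in_parallel_attn, while configs that omit the field now fall back to a single layer norm instead of raising AttributeError. A minimal sketch of the pattern follows; SimpleNamespace is a hypothetical stand-in for the real config object, which is not reproduced here.

from types import SimpleNamespace

# Hypothetical configs: a Falcon2-style config that sets the field,
# and an older-style config that omits it.
falcon2_config = SimpleNamespace(num_ln_in_parallel_attn=2)
older_config = SimpleNamespace()

# Before the fix, a direct attribute access would fail on the older config:
# older_config.num_ln_in_parallel_attn  ->  AttributeError

# After the fix, a missing value falls back to a single layer norm:
print(getattr(falcon2_config, "num_ln_in_parallel_attn", 1))  # 2
print(getattr(older_config, "num_ln_in_parallel_attn", 1))    # 1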