fix: setting the rotary base from the config for the grouped query models.

This commit is contained in:
Nilabhra 2024-05-14 10:14:18 +04:00
parent 46ada47963
commit c41573c67c

View File

@ -481,6 +481,7 @@ class FlashRWLayer(nn.Module):
return mlp_output, residual
class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()