fix: set the rotary base from the config for the grouped-query models.

This commit is contained in:
Nilabhra 2024-05-14 10:14:18 +04:00
parent 46ada47963
commit c41573c67c

View File

@ -481,6 +481,7 @@ class FlashRWLayer(nn.Module):
return mlp_output, residual
class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
@ -602,7 +603,7 @@ class FlashRWLargeLayer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"transformer.h.{layer_id}"
self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
self.self_attention = FlashRWLargeAttention(