fix: default use_qk_norm false in cohere

This commit is contained in:
drbh 2024-04-17 20:59:16 +00:00
parent 06c3d4b1ec
commit 91c653bac2

View File

@ -216,7 +216,9 @@ class FlashCohereAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.use_qk_norm = config.use_qk_norm self.use_qk_norm = (
config.use_qk_norm if hasattr(config, "use_qk_norm") else False
)
if self.use_qk_norm: if self.use_qk_norm:
self.q_norm = CohereLayerNorm( self.q_norm = CohereLayerNorm(
prefix=f"{prefix}.q_norm", prefix=f"{prefix}.q_norm",