Using head_dim as a fallback is necessary since it's a non-standard key in MistralConfig (as defined in transformers).
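
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the fallback described above; resolve_head_size and the SimpleNamespace configs are illustrative stand-ins, not code from this repository:

from types import SimpleNamespace


def resolve_head_size(config) -> int:
    # Prefer an explicit head_dim when the config defines one; otherwise fall
    # back to the conventional hidden_size // num_attention_heads.
    if hasattr(config, "head_dim"):
        return config.head_dim
    return config.hidden_size // config.num_attention_heads


# Older transformers MistralConfig objects do not carry head_dim.
legacy = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
assert resolve_head_size(legacy) == 128

# A config that does define head_dim takes precedence over the derived value.
custom = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=256)
assert resolve_head_size(custom) == 256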
Nicolas Patry 2024-07-23 12:56:37 +00:00
parent f0a5cb6c4e
commit ab62312d8c


@@ -66,7 +66,6 @@ class MistralConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
-        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -88,7 +87,6 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
-        self.head_dim = head_dim or hidden_size // num_attention_heads

         # for backward compatibility
         if num_key_value_heads is None:
@@ -119,7 +117,10 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = config.head_dim
+        if hasattr(config, "head_dim"):
+            self.head_size = config.head_dim
+        else:
+            self.head_size = hidden_size // num_attention_heads
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,