Using head_dim as a fallback is necessary since it's a non-standard key in MistralConfig (as defined in transformers).
Nicolas Patry 2024-07-23 12:56:37 +00:00
parent f0a5cb6c4e
commit ab62312d8c


@@ -66,7 +66,6 @@ class MistralConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
-        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -88,7 +87,6 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
-        self.head_dim = head_dim or hidden_size // num_attention_heads
         # for backward compatibility
         if num_key_value_heads is None:
@@ -119,7 +117,10 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = config.head_dim
+        if hasattr(config, "head_dim"):
+            self.head_size = config.head_dim
+        else:
+            self.head_size = hidden_size // num_attention_heads
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
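
For context, the fallback this diff switches to boils down to the pattern below. This is a minimal sketch, not the full attention module: SimpleNamespace stands in for a real MistralConfig instance, and the sizes are illustrative values only.

    from types import SimpleNamespace

    # Stand-in for a transformers MistralConfig that lacks the non-standard
    # `head_dim` attribute (illustrative sizes, not tied to any real checkpoint).
    config = SimpleNamespace(hidden_size=4096, num_attention_heads=32)

    # Prefer an explicit `head_dim` when the config provides one; otherwise
    # derive the per-head size from hidden_size and num_attention_heads.
    if hasattr(config, "head_dim"):
        head_size = config.head_dim
    else:
        head_size = config.hidden_size // config.num_attention_heads

    print(head_size)  # 4096 // 32 == 128

Configs that do define head_dim (the non-standard key the commit message refers to) take the first branch; stock transformers MistralConfig objects fall through to the derived value.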