Using head_dim as a fallback is necessary since it's a non-standard key in MistralConfig (as defined in transformers).
This commit is contained in:
parent f0a5cb6c4e
commit ab62312d8c
@@ -66,7 +66,6 @@ class MistralConfig(PretrainedConfig):
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        head_dim=None,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
@@ -88,7 +87,6 @@ class MistralConfig(PretrainedConfig):
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.head_dim = head_dim or hidden_size // num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
@@ -119,7 +117,10 @@ class MistralAttention(torch.nn.Module):
        )
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        if hasattr(config, "head_dim"):
            self.head_size = config.head_dim
        else:
            self.head_size = hidden_size // num_attention_heads

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,