Mirror of https://github.com/huggingface/text-generation-inference.git
Using head_dim as a fallback is necessary since it's a non-standard key in MistralConfig (as defined in transformers).
This commit is contained in:
parent f0a5cb6c4e
commit ab62312d8c
@@ -66,7 +66,6 @@ class MistralConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
-        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -88,7 +87,6 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
-        self.head_dim = head_dim or hidden_size // num_attention_heads

         # for backward compatibility
         if num_key_value_heads is None:
@@ -119,7 +117,10 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = config.head_dim
+        if hasattr(config, "head_dim"):
+            self.head_size = config.head_dim
+        else:
+            self.head_size = hidden_size // num_attention_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
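
To make the intent of the MistralAttention change concrete, the snippet below is a minimal, self-contained sketch of the same hasattr-based fallback. The resolve_head_size helper and the SimpleNamespace configs are illustrative stand-ins, not code from the repository, and the sketch reads hidden_size and num_attention_heads off the config object so it can run on its own.

from types import SimpleNamespace

def resolve_head_size(config) -> int:
    # Mirror the hunk above: prefer an explicit head_dim attribute when the
    # config defines one, otherwise derive the per-head size the usual way.
    if hasattr(config, "head_dim"):
        return config.head_dim
    return config.hidden_size // config.num_attention_heads

# A config without head_dim, like the standard transformers MistralConfig:
standard = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
print(resolve_head_size(standard))  # 4096 // 32 = 128

# A config that does carry the non-standard head_dim key:
custom = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=256)
print(resolve_head_size(custom))  # 256

Note that hasattr only checks for the attribute's presence, not its value: a config exposing head_dim=None would be passed through as-is, whereas the removed "head_dim or hidden_size // num_attention_heads" expression in MistralConfig also covered that case.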