From ab62312d8cad4c4c68a61102065d27ecb92b2b1b Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 23 Jul 2024 12:56:37 +0000
Subject: [PATCH] Using `head_dim` as a fallback is necessary since it's a
 non-standard key in MistralConfig (as defined in transformers).

---
 .../models/custom_modeling/flash_mistral_modeling.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 365c209e..a11ab1d3 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -66,7 +66,6 @@ class MistralConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
-        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -88,7 +87,6 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
-        self.head_dim = head_dim or hidden_size // num_attention_heads

         # for backward compatibility
         if num_key_value_heads is None:
@@ -119,7 +117,10 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = config.head_dim
+        if hasattr(config, "head_dim"):
+            self.head_size = config.head_dim
+        else:
+            self.head_size = self.hidden_size // self.num_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
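
For reference, the fallback can be exercised in isolation as sketched below. This is a standalone illustration of the `hasattr`-based check added in the last hunk, not code from the repository; `DummyConfig` and `head_size` are hypothetical names standing in for the transformers config object and the assignment in `MistralAttention.__init__`.

class DummyConfig:
    """Hypothetical stand-in for a transformers Mistral config object."""

    def __init__(self, hidden_size=4096, num_attention_heads=32, head_dim=None):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        if head_dim is not None:
            # Only set the attribute when given, mimicking configs where the
            # `head_dim` key is absent entirely.
            self.head_dim = head_dim


def head_size(config):
    # Same shape as the patched MistralAttention.__init__: prefer an explicit
    # `head_dim` attribute, otherwise derive it from the model dimensions.
    if hasattr(config, "head_dim"):
        return config.head_dim
    return config.hidden_size // config.num_attention_heads


assert head_size(DummyConfig()) == 128              # derived: 4096 // 32
assert head_size(DummyConfig(head_dim=256)) == 256  # explicit head_dim wins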