diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 365c209e..a11ab1d3 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -66,7 +66,6 @@ class MistralConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
-        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -88,7 +87,6 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
-        self.head_dim = head_dim or hidden_size // num_attention_heads
 
         # for backward compatibility
         if num_key_value_heads is None:
@@ -119,7 +117,10 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = config.head_dim
+        if hasattr(config, "head_dim"):
+            self.head_size = config.head_dim
+        else:
+            self.head_size = config.hidden_size // config.num_attention_heads
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
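The fallback added in MistralAttention.__init__ can be illustrated in isolation. The sketch below is a minimal standalone example, not TGI code: resolve_head_size is a hypothetical helper that mirrors the hasattr check from the diff, and the two SimpleNamespace objects stand in for a config without head_dim and a Nemo-style config whose explicit head_dim differs from hidden_size // num_attention_heads; the numeric values are for illustration only.

from types import SimpleNamespace


def resolve_head_size(config):
    # Mirrors the logic added in MistralAttention.__init__: prefer an explicit
    # head_dim attribute on the config, otherwise derive the per-head size
    # from hidden_size / num_attention_heads.
    if hasattr(config, "head_dim"):
        return config.head_dim
    return config.hidden_size // config.num_attention_heads


# Config without head_dim: the per-head size is derived from the model width.
plain = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
assert resolve_head_size(plain) == 128  # 4096 // 32

# Config with an explicit head_dim that does not equal hidden_size // num_attention_heads.
nemo_like = SimpleNamespace(hidden_size=5120, num_attention_heads=32, head_dim=128)
assert resolve_head_size(nemo_like) == 128  # head_dim wins over 5120 // 32 == 160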