From db7c519ee2f9934c2f3ead13facd47ce66df1ef0 Mon Sep 17 00:00:00 2001
From: Shaltiel Shmidman
Date: Fri, 19 Jul 2024 12:16:54 +0300
Subject: [PATCH] Support passing head_dim through config

---
 .../models/custom_modeling/flash_mistral_modeling.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 8028dbe8..365c209e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -66,6 +66,7 @@ class MistralConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
+        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -87,6 +88,7 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
+        self.head_dim = head_dim or hidden_size // num_attention_heads
 
         # for backward compatibility
         if num_key_value_heads is None:
@@ -117,7 +119,7 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
+        self.head_size = config.head_dim
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
@@ -146,15 +148,14 @@ class MistralAttention(torch.nn.Module):
             bias=False,
         )
 
-        head_size = config.hidden_size // config.num_attention_heads
         self.query_key_value = TensorParallelMultiAdapterLinear.load(
             query_key_value,
             layer_id,
             ["q_proj", "k_proj", "v_proj"],
             sizes=[
-                head_size * config.num_attention_heads,
-                head_size * config.num_key_value_heads,
-                head_size * config.num_key_value_heads,
+                self.head_size * config.num_attention_heads,
+                self.head_size * config.num_key_value_heads,
+                self.head_size * config.num_key_value_heads,
             ],
             process_group=weights.process_group,
         )
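
Illustration only, not part of the patch: a minimal self-contained sketch of the fallback behaviour the new head_dim field introduces. ToyConfig and its values are hypothetical and merely mirror the expression added to MistralConfig.__init__: an explicit head_dim from the checkpoint config is used as-is, while configs that omit it keep the old hidden_size // num_attention_heads computation.

# Hypothetical ToyConfig, for illustration only; mirrors the fallback expression
# added to MistralConfig.__init__ in this patch.
class ToyConfig:
    def __init__(self, hidden_size=4096, num_attention_heads=32, head_dim=None):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        # Use the explicit head_dim when the config provides one, otherwise fall
        # back to the previous behaviour of hidden_size // num_attention_heads.
        self.head_dim = head_dim or hidden_size // num_attention_heads

# Config without head_dim: identical to the pre-patch computation.
assert ToyConfig(hidden_size=4096, num_attention_heads=32).head_dim == 128
# Config with an explicit head_dim that differs from hidden_size / num_attention_heads:
assert ToyConfig(hidden_size=5120, num_attention_heads=32, head_dim=128).head_dim == 128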