Support passing head_dim through config

Shaltiel Shmidman 2024-07-19 12:16:54 +03:00
parent ba291dad9f
commit db7c519ee2


@@ -66,6 +66,7 @@ class MistralConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
+        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -87,6 +88,7 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
+        self.head_dim = head_dim or hidden_size // num_attention_heads

         # for backward compatibility
         if num_key_value_heads is None:
@@ -117,7 +119,7 @@ class MistralAttention(torch.nn.Module):
        )
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
+        self.head_size = config.head_dim

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
@@ -146,15 +148,14 @@ class MistralAttention(torch.nn.Module):
            bias=False,
        )

-        head_size = config.hidden_size // config.num_attention_heads
        self.query_key_value = TensorParallelMultiAdapterLinear.load(
            query_key_value,
            layer_id,
            ["q_proj", "k_proj", "v_proj"],
            sizes=[
-                head_size * config.num_attention_heads,
-                head_size * config.num_key_value_heads,
-                head_size * config.num_key_value_heads,
+                self.head_size * config.num_attention_heads,
+                self.head_size * config.num_key_value_heads,
+                self.head_size * config.num_key_value_heads,
            ],
            process_group=weights.process_group,
        )