Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

Commit db7c519ee2: Support passing head_dim through config
Parent: ba291dad9f

This commit lets `MistralConfig` accept an explicit `head_dim` and makes `MistralAttention` use that value instead of always deriving `hidden_size // num_attention_heads`.
```diff
@@ -66,6 +66,7 @@ class MistralConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
+        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
```
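The first hunk adds `head_dim=None` to the config signature, so checkpoints whose per-head width is not simply `hidden_size / num_attention_heads` can state it explicitly. A minimal sketch of the resulting constructor behavior (the class below is an illustrative stand-in, not the real `MistralConfig`):

```python
class SketchConfig:
    """Illustrative stand-in for the patched MistralConfig (not the real class)."""

    def __init__(self, hidden_size=4096, num_attention_heads=32, head_dim=None):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        # New behavior from this commit: honor an explicit head_dim,
        # otherwise fall back to the classic derivation.
        self.head_dim = head_dim or hidden_size // num_attention_heads


derived = SketchConfig()               # head_dim -> 4096 // 32 == 128
explicit = SketchConfig(head_dim=256)  # head_dim decoupled from hidden_size
assert derived.head_dim == 128 and explicit.head_dim == 256
```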
```diff
@@ -87,6 +88,7 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
+        self.head_dim = head_dim or hidden_size // num_attention_heads

         # for backward compatibility
         if num_key_value_heads is None:
```
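The fallback uses Python's `or`, so any falsy `head_dim` (both `None` and `0`) resolves to the classic derivation; that is harmless here, since a zero head dimension is never meaningful. Spelled out with illustrative numbers:

```python
hidden_size, num_attention_heads = 4096, 32

head_dim = None
resolved = head_dim or hidden_size // num_attention_heads
assert resolved == 128  # unset: falls back to the classic derivation

head_dim = 160
resolved = head_dim or hidden_size // num_attention_heads
assert resolved == 160  # explicit value wins
```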
```diff
@@ -117,7 +119,7 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
+        self.head_size = config.head_dim

         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
```
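With this hunk, `MistralAttention` stops recomputing the per-head width and reads the value the config has already resolved. A condensed sketch of the changed init logic (attributes unrelated to `head_size` are omitted, and `SketchAttention` is illustrative, not the real class):

```python
from types import SimpleNamespace


class SketchAttention:
    """Condensed illustration of the patched head_size logic."""

    def __init__(self, config):
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        # Before: self.head_size = self.hidden_size // self.num_heads
        # After:  trust the (possibly explicit) value resolved by the config.
        self.head_size = config.head_dim


cfg = SimpleNamespace(num_attention_heads=32, hidden_size=4096, head_dim=128)
attn = SketchAttention(cfg)
assert attn.head_size == 128
```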
```diff
@@ -146,15 +148,14 @@ class MistralAttention(torch.nn.Module):
             bias=False,
         )

-        head_size = config.hidden_size // config.num_attention_heads
         self.query_key_value = TensorParallelMultiAdapterLinear.load(
             query_key_value,
             layer_id,
             ["q_proj", "k_proj", "v_proj"],
             sizes=[
-                head_size * config.num_attention_heads,
-                head_size * config.num_key_value_heads,
-                head_size * config.num_key_value_heads,
+                self.head_size * config.num_attention_heads,
+                self.head_size * config.num_key_value_heads,
+                self.head_size * config.num_key_value_heads,
             ],
             process_group=weights.process_group,
         )
```
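The last hunk drops the locally derived `head_size` and reuses `self.head_size` when sizing the fused QKV projection, which matters under grouped-query attention, where K and V have fewer heads than Q. With illustrative Mistral-7B-like numbers:

```python
# Sketch of the fused QKV output widths (values are illustrative).
num_attention_heads = 32
num_key_value_heads = 8
head_size = 128  # now config.head_dim, no longer hidden_size // num_attention_heads

q_size = head_size * num_attention_heads  # 4096
k_size = head_size * num_key_value_heads  # 1024
v_size = head_size * num_key_value_heads  # 1024
# These are the `sizes=[...]` entries passed to TensorParallelMultiAdapterLinear.load.
print(q_size, k_size, v_size)
```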