mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 03:52:08 +00:00
Only n_heads / process_group.size() are necessary.
This commit is contained in:
parent
8d01848370
commit
8a4df6e181
@ -1001,7 +1001,7 @@ class FlashCausalLM(Model):
|
|||||||
config.sliding_window = None
|
config.sliding_window = None
|
||||||
|
|
||||||
self.num_layers = config.num_hidden_layers
|
self.num_layers = config.num_hidden_layers
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads // self.process_group.size()
|
||||||
# Validation is done in the model itself
|
# Validation is done in the model itself
|
||||||
if num_kv_heads is None:
|
if num_kv_heads is None:
|
||||||
num_kv_heads = getattr(config, "num_key_value_heads", None)
|
num_kv_heads = getattr(config, "num_key_value_heads", None)
|
||||||
|
Loading…
Reference in New Issue
Block a user