Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00
Fixing config.n_head.

parent 24bbd7b822
commit e8ff76fd18
@@ -884,14 +884,17 @@ class FlashCausalLM(Model):
         model = model_class(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
 
+        # VLM models define the config we care about in their text_config
         text_config = getattr(config, "text_config", None)
         if text_config is not None:
             config = text_config
         self.num_layers = config.num_hidden_layers
         # Validation is done in the model itself
-        num_heads = getattr(config, "num_key_value_heads", config.n_head)
         if num_kv_heads is None:
-            num_kv_heads = config.num_key_value_heads
+            num_kv_heads = getattr(config, "num_key_value_heads", None)
+            if num_kv_heads is None:
+                # Final overide for GPT2
+                num_kv_heads = config.n_head
         self.num_kv_heads = num_kv_heads // self.process_group.size()
         self.head_size = config.hidden_size // config.num_attention_heads
 
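As far as the hunk shows, the problem is twofold. Python evaluates the default argument of getattr() eagerly, so the removed num_heads line touches config.n_head even when num_key_value_heads exists, raising AttributeError on configs that lack n_head. And the old fallback read config.num_key_value_heads unconditionally, which fails the other way around on GPT-2 configs, which only expose n_head. A minimal sketch of the difference; the SimpleNamespace configs and the resolve_num_kv_heads helper are illustrative stand-ins, not code from the repository:

from types import SimpleNamespace

# Hypothetical stand-ins for real transformers config objects:
# a Llama-style config exposes num_key_value_heads but no n_head,
# a GPT-2 config exposes n_head but no num_key_value_heads.
llama_cfg = SimpleNamespace(num_key_value_heads=8)
gpt2_cfg = SimpleNamespace(n_head=12)

# Old pattern: llama_cfg.n_head is evaluated before getattr() is even
# called, so this raises on any config without n_head.
try:
    getattr(llama_cfg, "num_key_value_heads", llama_cfg.n_head)
except AttributeError as e:
    print(e)  # ... object has no attribute 'n_head'

def resolve_num_kv_heads(config, num_kv_heads=None):
    # New pattern: probe with a None default, touch n_head only if needed.
    if num_kv_heads is None:
        num_kv_heads = getattr(config, "num_key_value_heads", None)
        if num_kv_heads is None:
            # Final overide for GPT2
            num_kv_heads = config.n_head
    return num_kv_heads

print(resolve_num_kv_heads(llama_cfg))  # 8
print(resolve_num_kv_heads(gpt2_cfg))   # 12

The nested-if shape in the commit is what keeps the config.n_head access lazy: it only runs when num_key_value_heads is genuinely absent.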
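Once the fallback resolves, the unchanged context line self.num_kv_heads = num_kv_heads // self.process_group.size() shards the result across tensor-parallel ranks. A quick worked example of that arithmetic; the world sizes are assumptions for illustration, not values from the commit:

# Grouped-query model with 8 KV heads on 2 tensor-parallel ranks.
num_kv_heads = 8
world_size = 2            # what process_group.size() would return
print(num_kv_heads // world_size)  # 4 KV heads per rank

# GPT-2 small: n_head == 12, and with plain multi-head attention
# the KV-head count equals the head count.
num_kv_heads = 12
world_size = 4
print(num_kv_heads // world_size)  # 3 KV heads per rank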