Fixing config.n_head.

Nicolas Patry 2024-07-02 17:01:25 +00:00
parent 24bbd7b822
commit e8ff76fd18

@@ -884,14 +884,17 @@ class FlashCausalLM(Model):
         model = model_class(prefix, config, weights)

         torch.distributed.barrier(group=self.process_group)

+        # VLM models define the config we care about in their text_config
         text_config = getattr(config, "text_config", None)
         if text_config is not None:
             config = text_config
         self.num_layers = config.num_hidden_layers
         # Validation is done in the model itself
-        num_heads = getattr(config, "num_key_value_heads", config.n_head)
         if num_kv_heads is None:
-            num_kv_heads = config.num_key_value_heads
+            num_kv_heads = getattr(config, "num_key_value_heads", None)
+            if num_kv_heads is None:
+                # Final overide for GPT2
+                num_kv_heads = config.n_head
         self.num_kv_heads = num_kv_heads // self.process_group.size()
         self.head_size = config.hidden_size // config.num_attention_heads
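For context: GPT-2 style configs expose the head count as n_head and have no num_key_value_heads attribute, which is why the lookup now falls back to config.n_head only when num_kv_heads is not passed in and the attribute lookup fails. Below is a minimal sketch (not part of the commit) of the same fallback, using default transformers configs; the helper name resolve_num_kv_heads is made up for illustration.

    # Minimal sketch: mirrors the fallback added in this commit.
    # GPT2Config has no `num_key_value_heads` attribute, only `n_head`,
    # while Llama-style configs expose `num_key_value_heads` directly.
    from transformers import GPT2Config, LlamaConfig


    def resolve_num_kv_heads(config, num_kv_heads=None):
        # Hypothetical helper; same logic as the new block in FlashCausalLM.
        if num_kv_heads is None:
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            if num_kv_heads is None:
                # Fall back to the GPT-2 naming convention.
                num_kv_heads = config.n_head
        return num_kv_heads


    print(resolve_num_kv_heads(LlamaConfig()))  # 32 with default settings
    print(resolve_num_kv_heads(GPT2Config()))   # 12, taken from config.n_head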