diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 77125f53..5fc16fe3 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -884,14 +884,17 @@ class FlashCausalLM(Model):
         model = model_class(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
 
+        # VLM models define the config we care about in their text_config
         text_config = getattr(config, "text_config", None)
         if text_config is not None:
             config = text_config
 
         self.num_layers = config.num_hidden_layers
         # Validation is done in the model itself
-        num_heads = getattr(config, "num_key_value_heads", config.n_head)
         if num_kv_heads is None:
-            num_kv_heads = config.num_key_value_heads
+            num_kv_heads = getattr(config, "num_key_value_heads", None)
+            if num_kv_heads is None:
+                # Final override for GPT2
+                num_kv_heads = config.n_head
         self.num_kv_heads = num_kv_heads // self.process_group.size()
         self.head_size = config.hidden_size // config.num_attention_heads
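
For reference, here is a minimal runnable sketch (not part of the patch) of the fallback chain this hunk introduces. `resolve_num_kv_heads` is a hypothetical helper name, and `SimpleNamespace` objects stand in for `transformers` config classes: Llama-style configs expose `num_key_value_heads`, GPT-2's config only exposes `n_head`, and VLM configs nest the relevant text-model fields under `text_config`.

# Sketch only -- not part of the PR. SimpleNamespace stands in for
# transformers config objects, which carry many more fields in practice.
from types import SimpleNamespace


def resolve_num_kv_heads(config, num_kv_heads=None):
    # VLM models define the config we care about in their text_config.
    text_config = getattr(config, "text_config", None)
    if text_config is not None:
        config = text_config
    if num_kv_heads is None:
        # Most modern configs expose num_key_value_heads directly.
        num_kv_heads = getattr(config, "num_key_value_heads", None)
        if num_kv_heads is None:
            # Final override for GPT2, whose config only has n_head.
            num_kv_heads = config.n_head
    return num_kv_heads


# Llama-style config: num_key_value_heads is present and used as-is.
llama_like = SimpleNamespace(num_key_value_heads=8)
assert resolve_num_kv_heads(llama_like) == 8

# GPT-2-style config: no num_key_value_heads, so fall back to n_head.
gpt2_like = SimpleNamespace(n_head=12)
assert resolve_num_kv_heads(gpt2_like) == 12

# VLM-style config: the relevant fields live under text_config.
vlm_like = SimpleNamespace(text_config=SimpleNamespace(num_key_value_heads=4))
assert resolve_num_kv_heads(vlm_like) == 4

Compared with the removed code, which crashed on GPT-2 because `config.num_key_value_heads` does not exist there, the chained `getattr` lookups let one code path serve VLM, Llama-style, and GPT-2-style configs.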