From dcb727c232d4ffa1ebed1abc16d346c8489b8e4c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 14 May 2024 08:13:55 +0000 Subject: [PATCH] Reload model_type when speculator is found. --- server/text_generation_server/models/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index eabc8293..e9761dfe 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -158,6 +158,8 @@ def get_model( config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) + # Reload model type from parent. + model_type = config_dict.get("model_type", None) is_local = Path(medusa_model_id).exists() if not is_local: medusa_config = hf_hub_download( @@ -198,6 +200,8 @@ def get_model( config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) + # Reload model type from parent. + model_type = config_dict.get("model_type", None) is_local = Path(mlp_model_id).exists() extension = ".safetensors" if not is_local: