Reload model_type when speculator is found.

This commit is contained in:
Nicolas Patry 2024-05-14 08:13:55 +00:00
parent 6009dadee3
commit dcb727c232

View File

@ -158,6 +158,8 @@ def get_model(
config_dict, _ = PretrainedConfig.get_config_dict( config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
# Reload model type from parent.
model_type = config_dict.get("model_type", None)
is_local = Path(medusa_model_id).exists() is_local = Path(medusa_model_id).exists()
if not is_local: if not is_local:
medusa_config = hf_hub_download( medusa_config = hf_hub_download(
@ -198,6 +200,8 @@ def get_model(
config_dict, _ = PretrainedConfig.get_config_dict( config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
# Reload model type from parent.
model_type = config_dict.get("model_type", None)
is_local = Path(mlp_model_id).exists() is_local = Path(mlp_model_id).exists()
extension = ".safetensors" extension = ".safetensors"
if not is_local: if not is_local: