mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Reload model_type when speculator is found.
This commit is contained in:
parent
6009dadee3
commit
dcb727c232
@ -158,6 +158,8 @@ def get_model(
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
# Reload model type from parent.
|
||||
model_type = config_dict.get("model_type", None)
|
||||
is_local = Path(medusa_model_id).exists()
|
||||
if not is_local:
|
||||
medusa_config = hf_hub_download(
|
||||
@ -198,6 +200,8 @@ def get_model(
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
# Reload model type from parent.
|
||||
model_type = config_dict.get("model_type", None)
|
||||
is_local = Path(mlp_model_id).exists()
|
||||
extension = ".safetensors"
|
||||
if not is_local:
|
||||
|
Loading…
Reference in New Issue
Block a user